# Isotropic state distillation — multi-copy


In [1]:
import time
import random
import math

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


import quairkit as qkit
from quairkit import Circuit
from quairkit import to_state
from quairkit.database import *
from quairkit.loss import *
from quairkit.qinfo import *
from quairkit.database.hamiltonian import ising_hamiltonian
from quairkit.ansatz import *
from quairkit.operator import ParamOracle

qkit.set_dtype('complex128')

## Dynamic LOCCNet

In [None]:
def dynloccnetcir(n):
    """Build one dynamic-LOCCNet round acting on ``2*n`` qubits.

    Two ``universal_qudits`` layers are added: one on all even-indexed qubits
    and one on all odd-indexed qubits (presumably the two parties' halves of
    the ``n`` copied pairs). The index lists are derived from ``n``; for
    ``n == 4`` they are ``[0, 2, 4, 6]`` and ``[1, 3, 5, 7]``.

    Args:
        n: number of two-qubit copies processed by this round.

    Returns:
        A ``Circuit`` on ``2*n`` qubits.
    """
    cir = Circuit(2*n)
    cir.universal_qudits(qubits_idx=list(range(0, 2*n, 2)))
    cir.universal_qudits(qubits_idx=list(range(1, 2*n, 2)))

    return cir

## 5-copy

In [None]:
def loss_func_dyn_5(cir1, cir2, target_state, noisy_state2):
    """Two-round dynamic LOCCNet loss on 5 copies of ``noisy_state2``.

    Round 1 applies ``cir1`` to four copies (8 qubits). Qubits 0-1 are then
    measured in the z basis with ``desired_result='00'`` and traced out. One
    fresh copy is appended, and ``cir2`` is applied. Finally qubits 0-5 are
    measured (all-zero outcome) and traced out, leaving a two-qubit state.
    The hard-coded ``[4, 2**6]`` splits assume ``noisy_state2`` is a
    two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    input_state1 = torch.kron(torch.kron(torch.kron(noisy_state2, noisy_state2), noisy_state2), noisy_state2)
    state1 = cir1(to_state(input_state1, eps=None))
    measure_state = Measure('z' * 2)
    _, m_state = measure_state(state1, qubits_idx=list(range(2)), keep_state=True, desired_result='0' * 2)
    # Drop the two measured qubits; the remaining 6 qubits carry on.
    output_state1 = partial_trace(m_state, 0, [4, 2**6]).density_matrix

    input_state2 = torch.kron(output_state1, noisy_state2)
    state2 = cir2(to_state(input_state2, eps=None))
    measure_state2 = Measure('z' * 6)
    _, m_state2 = measure_state2(state2, qubits_idx=list(range(6)), keep_state=True, desired_result='0' * 6)
    output_state = partial_trace(m_state2, 0, [2**6, 4]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_dyn_5(num_itr, LR, n, target_state,noisy_state2):
    """Train the two-round dynamic LOCCNet protocol on 5 copies of ``noisy_state2``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        n: copies per round, forwarded to ``dynloccnetcir`` (``loss_func_dyn_5`` hard-codes n == 4).
        target_state: target density matrix.
        noisy_state2: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    circuits = tuple(dynloccnetcir(n) for _ in range(2))
    # Adam state is per-parameter, so one optimizer over all circuits matches one
    # optimizer per circuit with the same LR; one scheduler likewise replaces
    # identical per-circuit schedulers that were all fed the same loss.
    optimizer = torch.optim.Adam(lr=LR, params=[p for cir in circuits for p in cir.parameters()])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_dyn_5(*circuits, target_state, noisy_state2)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

## 6-copy

In [None]:
def loss_func_dyn_6(cir1, cir2, cir3, target_state, noisy_state2):
    """Three-round dynamic LOCCNet loss on 6 copies of ``noisy_state2``.

    Round 1 applies ``cir1`` to four copies (8 qubits). After each
    intermediate circuit (``cir1``, ``cir2``), qubits 0-1 are measured in the
    z basis with ``desired_result='00'`` and traced out, and one fresh copy is
    appended. After ``cir3``, qubits 0-5 are measured (all-zero outcome) and
    traced out, leaving a two-qubit state. The hard-coded ``[4, 2**6]`` splits
    assume ``noisy_state2`` is a two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    measure_pair = Measure('z' * 2)
    round_input = torch.kron(torch.kron(torch.kron(noisy_state2, noisy_state2), noisy_state2), noisy_state2)
    for cir in (cir1, cir2):
        evolved = cir(to_state(round_input, eps=None))
        _, kept = measure_pair(evolved, qubits_idx=list(range(2)), keep_state=True, desired_result='0' * 2)
        # Drop the two measured qubits, then add one fresh copy for the next round.
        carried = partial_trace(kept, 0, [4, 2**6]).density_matrix
        round_input = torch.kron(carried, noisy_state2)

    final_state = cir3(to_state(round_input, eps=None))
    measure_six = Measure('z' * 6)
    _, kept = measure_six(final_state, qubits_idx=list(range(6)), keep_state=True, desired_result='0' * 6)
    output_state = partial_trace(kept, 0, [2**6, 4]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_dyn_6(num_itr, LR, n, target_state,noisy_state):
    """Train the three-round dynamic LOCCNet protocol on 6 copies of ``noisy_state``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        n: copies per round, forwarded to ``dynloccnetcir`` (``loss_func_dyn_6`` hard-codes n == 4).
        target_state: target density matrix.
        noisy_state: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    circuits = tuple(dynloccnetcir(n) for _ in range(3))
    # Adam state is per-parameter, so one optimizer over all circuits matches one
    # optimizer per circuit with the same LR; one scheduler likewise replaces
    # identical per-circuit schedulers that were all fed the same loss.
    optimizer = torch.optim.Adam(lr=LR, params=[p for cir in circuits for p in cir.parameters()])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_dyn_6(*circuits, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

## 7-copy

In [None]:
def loss_func_dyn_7(cir1, cir2, cir3, cir4, target_state, noisy_state2):
    """Four-round dynamic LOCCNet loss on 7 copies of ``noisy_state2``.

    Round 1 applies ``cir1`` to four copies (8 qubits). After each
    intermediate circuit (``cir1``..``cir3``), qubits 0-1 are measured in the
    z basis with ``desired_result='00'`` and traced out, and one fresh copy is
    appended. After ``cir4``, qubits 0-5 are measured (all-zero outcome) and
    traced out, leaving a two-qubit state. The hard-coded ``[4, 2**6]`` splits
    assume ``noisy_state2`` is a two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    measure_pair = Measure('z' * 2)
    round_input = torch.kron(torch.kron(torch.kron(noisy_state2, noisy_state2), noisy_state2), noisy_state2)
    for cir in (cir1, cir2, cir3):
        evolved = cir(to_state(round_input, eps=None))
        _, kept = measure_pair(evolved, qubits_idx=list(range(2)), keep_state=True, desired_result='0' * 2)
        # Drop the two measured qubits, then add one fresh copy for the next round.
        carried = partial_trace(kept, 0, [4, 2**6]).density_matrix
        round_input = torch.kron(carried, noisy_state2)

    final_state = cir4(to_state(round_input, eps=None))
    measure_six = Measure('z' * 6)
    _, kept = measure_six(final_state, qubits_idx=list(range(6)), keep_state=True, desired_result='0' * 6)
    output_state = partial_trace(kept, 0, [2**6, 4]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_dyn_7(num_itr, LR, n, target_state,noisy_state):
    """Train the four-round dynamic LOCCNet protocol on 7 copies of ``noisy_state``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        n: copies per round, forwarded to ``dynloccnetcir`` (``loss_func_dyn_7`` hard-codes n == 4).
        target_state: target density matrix.
        noisy_state: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    circuits = tuple(dynloccnetcir(n) for _ in range(4))
    # Adam state is per-parameter, so one optimizer over all circuits matches one
    # optimizer per circuit with the same LR; one scheduler likewise replaces
    # identical per-circuit schedulers that were all fed the same loss.
    optimizer = torch.optim.Adam(lr=LR, params=[p for cir in circuits for p in cir.parameters()])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_dyn_7(*circuits, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

## 8-copy

In [None]:
def loss_func_dyn_8(cir1, cir2, cir3, cir4, cir5, target_state, noisy_state2):
    """Five-round dynamic LOCCNet loss on 8 copies of ``noisy_state2``.

    Round 1 applies ``cir1`` to four copies (8 qubits). After each
    intermediate circuit (``cir1``..``cir4``), qubits 0-1 are measured in the
    z basis with ``desired_result='00'`` and traced out, and one fresh copy is
    appended. After ``cir5``, qubits 0-5 are measured (all-zero outcome) and
    traced out, leaving a two-qubit state. The hard-coded ``[4, 2**6]`` splits
    assume ``noisy_state2`` is a two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    measure_pair = Measure('z' * 2)
    round_input = torch.kron(torch.kron(torch.kron(noisy_state2, noisy_state2), noisy_state2), noisy_state2)
    for cir in (cir1, cir2, cir3, cir4):
        evolved = cir(to_state(round_input, eps=None))
        _, kept = measure_pair(evolved, qubits_idx=list(range(2)), keep_state=True, desired_result='0' * 2)
        # Drop the two measured qubits, then add one fresh copy for the next round.
        carried = partial_trace(kept, 0, [4, 2**6]).density_matrix
        round_input = torch.kron(carried, noisy_state2)

    final_state = cir5(to_state(round_input, eps=None))
    measure_six = Measure('z' * 6)
    _, kept = measure_six(final_state, qubits_idx=list(range(6)), keep_state=True, desired_result='0' * 6)
    output_state = partial_trace(kept, 0, [2**6, 4]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_dyn_8(num_itr, LR, n, target_state,noisy_state):
    """Train the five-round dynamic LOCCNet protocol on 8 copies of ``noisy_state``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        n: copies per round, forwarded to ``dynloccnetcir`` (``loss_func_dyn_8`` hard-codes n == 4).
        target_state: target density matrix.
        noisy_state: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    circuits = tuple(dynloccnetcir(n) for _ in range(5))
    # Adam state is per-parameter, so one optimizer over all circuits matches one
    # optimizer per circuit with the same LR; one scheduler likewise replaces
    # identical per-circuit schedulers that were all fed the same loss.
    optimizer = torch.optim.Adam(lr=LR, params=[p for cir in circuits for p in cir.parameters()])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_dyn_8(*circuits, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

## 9-copy

In [None]:
def loss_func_dyn_9(cir1, cir2, cir3, cir4, cir5, cir6, target_state, noisy_state2):
    """Six-round dynamic LOCCNet loss on 9 copies of ``noisy_state2``.

    Round 1 applies ``cir1`` to four copies (8 qubits). After each
    intermediate circuit (``cir1``..``cir5``), qubits 0-1 are measured in the
    z basis with ``desired_result='00'`` and traced out, and one fresh copy is
    appended. After ``cir6``, qubits 0-5 are measured (all-zero outcome) and
    traced out, leaving a two-qubit state. The hard-coded ``[4, 2**6]`` splits
    assume ``noisy_state2`` is a two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    measure_pair = Measure('z' * 2)
    round_input = torch.kron(torch.kron(torch.kron(noisy_state2, noisy_state2), noisy_state2), noisy_state2)
    for cir in (cir1, cir2, cir3, cir4, cir5):
        evolved = cir(to_state(round_input, eps=None))
        _, kept = measure_pair(evolved, qubits_idx=list(range(2)), keep_state=True, desired_result='0' * 2)
        # Drop the two measured qubits, then add one fresh copy for the next round.
        carried = partial_trace(kept, 0, [4, 2**6]).density_matrix
        round_input = torch.kron(carried, noisy_state2)

    final_state = cir6(to_state(round_input, eps=None))
    measure_six = Measure('z' * 6)
    _, kept = measure_six(final_state, qubits_idx=list(range(6)), keep_state=True, desired_result='0' * 6)
    output_state = partial_trace(kept, 0, [2**6, 4]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_dyn_9(num_itr, LR, n, target_state,noisy_state):
    """Train the six-round dynamic LOCCNet protocol on 9 copies of ``noisy_state``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        n: copies per round, forwarded to ``dynloccnetcir`` (``loss_func_dyn_9`` hard-codes n == 4).
        target_state: target density matrix.
        noisy_state: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    circuits = tuple(dynloccnetcir(n) for _ in range(6))
    # Adam state is per-parameter, so one optimizer over all circuits matches one
    # optimizer per circuit with the same LR; one scheduler likewise replaces
    # identical per-circuit schedulers that were all fed the same loss.
    optimizer = torch.optim.Adam(lr=LR, params=[p for cir in circuits for p in cir.parameters()])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_dyn_9(*circuits, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

## 10-copy

In [None]:
def loss_func_dyn_10(cir1, cir2, cir3, cir4, cir5, cir6, cir7, target_state, noisy_state2):
    """Seven-round dynamic LOCCNet loss on 10 copies of ``noisy_state2``.

    Round 1 applies ``cir1`` to four copies (8 qubits). After each
    intermediate circuit (``cir1``..``cir6``), qubits 0-1 are measured in the
    z basis with ``desired_result='00'`` and traced out, and one fresh copy is
    appended. After ``cir7``, qubits 0-5 are measured (all-zero outcome) and
    traced out, leaving a two-qubit state. The hard-coded ``[4, 2**6]`` splits
    assume ``noisy_state2`` is a two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    measure_pair = Measure('z' * 2)
    round_input = torch.kron(torch.kron(torch.kron(noisy_state2, noisy_state2), noisy_state2), noisy_state2)
    for cir in (cir1, cir2, cir3, cir4, cir5, cir6):
        evolved = cir(to_state(round_input, eps=None))
        _, kept = measure_pair(evolved, qubits_idx=list(range(2)), keep_state=True, desired_result='0' * 2)
        # Drop the two measured qubits, then add one fresh copy for the next round.
        carried = partial_trace(kept, 0, [4, 2**6]).density_matrix
        round_input = torch.kron(carried, noisy_state2)

    final_state = cir7(to_state(round_input, eps=None))
    measure_six = Measure('z' * 6)
    _, kept = measure_six(final_state, qubits_idx=list(range(6)), keep_state=True, desired_result='0' * 6)
    output_state = partial_trace(kept, 0, [2**6, 4]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_dyn_10(num_itr, LR, n, target_state,noisy_state):
    """Train the seven-round dynamic LOCCNet protocol on 10 copies of ``noisy_state``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        n: copies per round, forwarded to ``dynloccnetcir`` (``loss_func_dyn_10`` hard-codes n == 4).
        target_state: target density matrix.
        noisy_state: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    circuits = tuple(dynloccnetcir(n) for _ in range(7))
    # Adam state is per-parameter, so one optimizer over all circuits matches one
    # optimizer per circuit with the same LR; one scheduler likewise replaces
    # identical per-circuit schedulers that were all fed the same loss.
    optimizer = torch.optim.Adam(lr=LR, params=[p for cir in circuits for p in cir.parameters()])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_dyn_10(*circuits, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

## LOCCNet

In [15]:
def loccnetcir(m):
    """Build a single-round LOCCNet circuit on ``2*m`` qubits.

    One ``universal_qudits`` layer acts on every even-indexed qubit and a
    second one on every odd-indexed qubit (presumably the two parties' halves
    of the ``m`` copied pairs).

    Args:
        m: number of two-qubit copies.

    Returns:
        A ``Circuit`` on ``2*m`` qubits.
    """
    num_qubits = 2 * m
    even_side = list(range(0, num_qubits, 2))
    odd_side = list(range(1, num_qubits, 2))

    cir = Circuit(num_qubits)
    cir.universal_qudits(qubits_idx=even_side)
    cir.universal_qudits(qubits_idx=odd_side)
    return cir

In [16]:
def loss_func_loccnet(cir1, m, target_state, noisy_state2):
    """Single-round LOCCNet loss on ``m`` copies of ``noisy_state2``.

    All ``m`` copies (``2*m`` qubits) pass through ``cir1``. The first
    ``2*m - 2`` qubits are then measured in the z basis with an all-zero
    ``desired_result`` and traced out, leaving a two-qubit state.

    Returns:
        ``(loss, output_state, f)``, where ``F = state_fidelity(target_state,
        output_state)``: ``loss = 1 - F**2`` is differentiable, ``output_state``
        is the final density matrix, and ``f`` is ``F**2`` as a Python float.
    """
    num_measured = 2*m - 2

    input_state1 = nkron(*[noisy_state2 for _ in range(m)])
    state1 = cir1(to_state(input_state1, eps=None))
    measure_state = Measure('z' * num_measured)
    _, m_state = measure_state(state1, qubits_idx=list(range(num_measured)), keep_state=True, desired_result='0' * num_measured)
    # Keep only the last two (unmeasured) qubits.
    output_state = partial_trace(m_state, 0, [2**num_measured, 2**2]).density_matrix

    # Evaluate the fidelity once and reuse it for both the loss and the float.
    fidelity = state_fidelity(target_state, output_state)
    loss = 1 - fidelity**2
    f = fidelity.item()**2

    return loss, output_state, f

In [None]:
def train_model_loccnet(num_itr, LR, m, target_state,noisy_state):
    """Train a single-round LOCCNet on ``m`` copies of ``noisy_state``.

    Args:
        num_itr: number of optimisation steps (must be >= 1).
        LR: initial Adam learning rate; ``ReduceLROnPlateau`` halves it when the loss stalls.
        m: number of copies (the circuit acts on ``2*m`` qubits).
        target_state: target density matrix.
        noisy_state: noisy two-qubit input state.

    Returns:
        Squared ``state_fidelity`` of the output from the last forward pass
        (evaluated before that iteration's optimizer step) w.r.t. ``target_state``.

    Raises:
        ValueError: if ``num_itr`` < 1, since no output state would be produced.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be at least 1, got {num_itr}")

    cir1 = loccnetcir(m)
    optimizer = torch.optim.Adam(lr=LR, params=cir1.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')

    time_list = []
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        loss, output_state, _ = loss_func_loccnet(cir1, m, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            # Read the LR from the optimizer: ReduceLROnPlateau has no
            # get_last_lr() in older PyTorch releases.
            print(
                f"iter: {itr}, loss: {loss:.8f}, lr: {optimizer.param_groups[0]['lr']:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    return state_fidelity(output_state.detach(), target_state).item()**2

# Train

In [None]:
n = 4 # copies processed per dynamic-LOCCNet round (dynloccnetcir builds 2*n qubits)
m = 6 # copies fed to the single-round LOCCNet baseline (loccnetcir builds 2*m qubits)
NUM_ITR = 2000 # training iterations
LR = 0.1 # initial learning rate
target_state = bell_state(2).density_matrix # target state

fid_dynamic = []
fid_loccnet = []
fid_iso1 = []

p1 = 8 # noise parameter times 10: the isotropic weight is p = p1 / 10
iso_state = (p1/10) * bell_state(2).density_matrix + (1-p1/10) * eye(4)/4
fid1 = state_fidelity(iso_state,target_state).item()**2
fid_iso1.append(fid1)

## 4 copies of isotropic state, dloccnet is equal to loccnet in this case
# f4 = train_model_loccnet(NUM_ITR, LR, 4, target_state,iso_state)
# fid_dynamic.append(f4)
# fid_loccnet.append(f4)

## dloccnet
### 5 copies of isotropic
# fdyn5 = train_model_dyn_5(NUM_ITR, LR, n, target_state,iso_state)
# fid_dynamic.append(fdyn5)
### 6 copies of isotropic
# fdyn6 = train_model_dyn_6(NUM_ITR, LR, n, target_state,iso_state)
# fid_dynamic.append(fdyn6)
### 7 copies of isotropic
# fdyn7 = train_model_dyn_7(NUM_ITR, LR, n, target_state,iso_state)
# fid_dynamic.append(fdyn7)
### 8 copies of isotropic
# fdyn8 = train_model_dyn_8(NUM_ITR, LR, n, target_state,iso_state)
# fid_dynamic.append(fdyn8)
### 9 copies of isotropic
# fdyn9 = train_model_dyn_9(NUM_ITR, LR, n, target_state,iso_state)
# fid_dynamic.append(fdyn9)
### 10 copies of isotropic
fdyn10 = train_model_dyn_10(NUM_ITR, LR, n, target_state,iso_state)
fid_dynamic.append(fdyn10)

## loccnet
### 5 copies of isotropic
# f5 = train_model_loccnet(NUM_ITR, LR, 5, target_state,iso_state)
# fid_loccnet.append(f5)
### 6 copies of isotropic (m == 6; the copy count passed here IS the number of copies trained)
# NOTE: 6 copies means 12-qubit density matrices, far heavier than the 8-qubit dynamic rounds.
f6 = train_model_loccnet(NUM_ITR, LR, m, target_state,iso_state)
fid_loccnet.append(f6)

print('10-copy of dyamic 4-3',fid_dynamic)
print('loccnet',fid_loccnet)