#### Code by Nicolas Perez

#### MIT license. If you use this code, it is much appreciated if you credit Nicolas Perez, explicitly reference the "A Framework for Authenticity, Integrity and Replay Protection in Quantum Data Communication" paper, and credit all of the authors of the paper.

### The code in this notebook implements the replay detection simulation from the "A Framework for Authenticity, Integrity and Replay Protection in Quantum Data Communication" paper.

In [None]:
import numpy as np
from math import sqrt
import cmath
from scipy.stats import unitary_group
import random
import scipy.stats
import os
from tabulate import tabulate

In [None]:
def depolarizing_channel(lambda_param, qubit):
    """Apply a single-qubit depolarizing channel to a density matrix.

    Computes ``lambda_param * qubit + (1 - lambda_param) * I / 2``. The input
    state is kept with weight ``lambda_param``, and the remaining weight goes
    to the maximally mixed state ``I / 2``.

    Args:
        lambda_param: Channel parameter; ``1`` leaves the state unchanged and
            ``0`` maps every state to ``I / 2``.
        qubit: Density matrix to transform (assumed 2x2, since a 2x2 identity
            is added to it).

    Returns:
        The output density matrix.
    """
    maximally_mixed_weight = (1 - lambda_param) / 2
    return lambda_param * qubit + maximally_mixed_weight * np.identity(2)

def depolarize_states(lambda_param, states):
    """Send every density matrix in ``states`` through ``depolarizing_channel``.

    Args:
        lambda_param: Depolarizing parameter applied to each state.
        states: Iterable of density matrices.

    Returns:
        A new list holding the channel output for each input state, in the
        same order as ``states``.
    """
    return [depolarizing_channel(lambda_param, rho) for rho in states]

def normal_distribution_confidence_interval(data, z=1.96):
    r"""Summarise ``data`` as mean +/- a normal-approximation half-width.

    The half-width is ``z`` times the standard error of the mean, as computed
    by ``scipy.stats.sem``. The default ``z = 1.96`` corresponds to a
    two-sided 95% interval.

    The LaTeX math string (for example ``$2.00\pm1.13$``) is printed and also
    returned.

    Args:
        data: Sequence of numbers.
        z: Critical value that multiplies the standard error.

    Returns:
        Tuple ``(lower, mean, upper, latex_string)``.
    """
    data = 1.0 * np.array(data)
    mean = np.mean(data)
    diff = scipy.stats.sem(data) * z
    # Round before formatting, matching the original two-step formatting.
    m_str = "%.2f" % round(mean, 2)
    diff_str = "%.2f" % round(diff, 2)
    # Raw string: "\p" is not a valid escape sequence and triggers a
    # SyntaxWarning on Python 3.12+. The resulting text is unchanged.
    confidence_interval_string = "$" + m_str + r"\pm" + diff_str + "$"
    print(confidence_interval_string)
    return mean - diff, mean, mean + diff, confidence_interval_string

def generate_random_unitary_matrix():
    """Draw a random 2x2 unitary matrix with ``scipy.stats.unitary_group``.

    SciPy documents ``unitary_group`` as sampling from the Haar distribution.
    """
    dimension = 2
    return unitary_group.rvs(dimension)

def generate_pairs_of_random_unitaries_and_states(num_states):
    """Create ``num_states`` random unitaries and the pure states they prepare.

    Each state is the density matrix ``U|1><1|U^dagger``, where ``U`` is a
    fresh random unitary and ``|1>`` is the basis vector ``[0, 1]``.

    Args:
        num_states: Number of (unitary, state) pairs to generate.

    Returns:
        ``(unitaries_list, states_list)``: two lists of 2x2 matrices, where
        ``states_list[k]`` was prepared by ``unitaries_list[k]``.
    """
    basis_one = np.array([0, 1])
    unitaries_list, states_list = [], []
    for _ in range(num_states):
        u = generate_random_unitary_matrix()
        psi = np.matmul(u, basis_one)
        unitaries_list.append(u)
        states_list.append(np.outer(psi, np.conj(psi)))
    return unitaries_list, states_list


def wilson_distribution_confidence_interval(successful_trials, total_trials, alpha):
    """Return the one-sided Wilson score lower bound for a binomial proportion.

    The critical value is ``z = norm.ppf(1 - alpha)``, so the result is a
    one-sided lower confidence bound at level ``1 - alpha``.

    Args:
        successful_trials: Number of successes observed.
        total_trials: Number of trials.
        alpha: Significance level. Smaller values give a larger ``z`` and
            therefore a lower, more conservative bound.

    Returns:
        Tuple ``(lower_bound, phat)``, where ``phat`` is the observed success
        fraction. Both values are ``0.0`` when ``total_trials == 0``.
    """
    n = total_trials
    if n == 0:
        # Always return a pair so callers can unpack two values.
        return 0.0, 0.0
    z = scipy.stats.norm.ppf(1 - alpha)
    phat = float(successful_trials) / n
    z2 = z * z
    centre = phat + z2 / (2 * n)
    margin = z * sqrt((phat * (1 - phat) + z2 / (4 * n)) / n)
    return (centre - margin) / (1 + z2 / n), phat

def estimate_channel_fidelity(unitaries, depolarized_states, alpha):
    """Estimate average channel fidelity by simulating one measurement per state.

    For each pair ``(U, rho)``, the preparation unitary is undone
    (``U^dagger rho U``) and a measurement is simulated. The outcome is ``1``
    with probability equal to the ``[1, 1]`` entry of the rotated state. This
    counts as "correct" assuming each state was prepared as ``U|1>``.

    Args:
        unitaries: Preparation unitaries, one per state.
        depolarized_states: Density matrices after the channel, aligned
            index-by-index with ``unitaries``.
        alpha: Significance level forwarded to
            ``wilson_distribution_confidence_interval``.

    Returns:
        Tuple ``(lower_bound, fraction_correct, measurements_list,
        total_correct_measurements)``, where ``measurements_list`` holds one
        0/1 outcome per state, in order.

    Raises:
        ValueError: If the two sequences differ in length or are empty.
    """
    if len(unitaries) != len(depolarized_states):
        raise ValueError("unitaries and depolarized_states must have the same length")
    total_states = len(depolarized_states)
    if total_states == 0:
        raise ValueError("at least one state is required to estimate fidelity")

    measurements_list = []
    for unitary, depolarized_state in zip(unitaries, depolarized_states):
        inverse_operation = np.conj(unitary).T
        state_to_measure = inverse_operation @ depolarized_state @ unitary
        # Diagonal entries of a density matrix are real. Drop the numerically
        # zero imaginary part so this is an ordinary real comparison, rather
        # than relying on numpy's lexicographic ordering of complex numbers.
        probability_of_1 = np.real(state_to_measure[1, 1])
        measurements_list.append(1 if random.uniform(0, 1) <= probability_of_1 else 0)
    total_correct_measurements = sum(measurements_list)

    lower_bound, _ = wilson_distribution_confidence_interval(total_correct_measurements, total_states,
                                                             alpha)

    return lower_bound, total_correct_measurements / total_states, measurements_list, total_correct_measurements

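Test that the average quantum channel fidelity is being estimated correctly, and modify these values to experiment! For the depolarizing channel with parameter $\lambda$, the expected fidelity is $(1+\lambda)/2$ (e.g. $0.95$ for $\lambda = 0.9$). The `depolarize_states` function can be replaced with simulations of other quantum channels.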

In [None]:
# Baseline run: estimate the fidelity of a single depolarizing channel.
# For this channel the expected fraction of correct outcomes is
# (1 + lambda_param) / 2, i.e. 0.95 for lambda_param = 0.9.
num_states_used_to_estimate_fidelity = 500
lambda_param = 0.9
alpha = 0.05  # significance level of the one-sided Wilson lower bound

unitaries, states = generate_pairs_of_random_unitaries_and_states(num_states_used_to_estimate_fidelity)
depolarized_states = depolarize_states(lambda_param, states)
# measurements_list / num_correct_measurements seed the sliding window used by later cells.
lower_bound, average, measurements_list, num_correct_measurements = estimate_channel_fidelity(unitaries, depolarized_states, 
                                                                                              alpha)

In [None]:
# Wilson lower bound, observed fraction of correct outcomes, and the raw count.
print(lower_bound, average, num_correct_measurements)

In the associated paper, a channel is monitored by keeping track of the last $j$ received states and using them to estimate the average channel fidelity. Significant changes in fidelity indicate that an adversary is attacking the channel. The adversary's disturbance on the channel is given by the `lambda_2` parameter of the `enqueue_measurement_result` function below, while `lambda_1` and `lambda_3` correspond to the channels used by the source and the destination, respectively.

In [None]:
def enqueue_measurement_result(lambda_1, lambda_2, lambda_3, measurements_list, num_correct_measurements,
                                   unitaries, states, current_state_and_unitary_index):
    """Slide the measurement window forward by one simulated qubit.

    Steps:

    1. Remove the oldest outcome from the front of ``measurements_list`` and
       subtract it from ``num_correct_measurements``.
    2. Take the state and unitary at ``current_state_and_unitary_index``
       (modulo ``len(unitaries)``).
    3. Pass the state through three depolarizing channels in sequence:
       ``lambda_1`` (source), ``lambda_2`` (adversary disturbance) and
       ``lambda_3`` (destination).
    4. Undo the preparation unitary and simulate a measurement whose outcome
       is ``1`` with probability equal to the ``[1, 1]`` entry of the rotated
       state, then append that outcome.

    ``measurements_list`` is mutated in place and also returned.

    Args:
        lambda_1, lambda_2, lambda_3: Depolarizing parameters applied in order.
        measurements_list: Current window of 0/1 outcomes (must be non-empty).
        num_correct_measurements: Number of 1s currently in the window.
        unitaries: Preparation unitaries.
        states: Density matrices aligned with ``unitaries``.
        current_state_and_unitary_index: Index of the pair to use; wraps
            around modulo ``len(unitaries)``.

    Returns:
        ``(estimate, measurements_list, num_correct_measurements)``, where
        ``estimate`` is the fraction of 1s in the updated window.

    Raises:
        IndexError: If ``measurements_list`` is empty.
    """
    # Drop the oldest outcome so the window length stays constant.
    num_correct_measurements = num_correct_measurements - measurements_list.pop(0)
    index = current_state_and_unitary_index % len(unitaries)
    random_unitary = unitaries[index]
    random_state = states[index]

    depolarized_state = random_state
    for channel_lambda in (lambda_1, lambda_2, lambda_3):
        depolarized_state = depolarizing_channel(channel_lambda, depolarized_state)

    inverse_operation = np.conj(random_unitary).T
    state_to_measure = inverse_operation @ depolarized_state @ random_unitary

    # Diagonal entries are real; compare against the real part explicitly.
    probability_of_1 = np.real(state_to_measure[1, 1])
    if random.uniform(0, 1) <= probability_of_1:
        measurements_list.append(1)
        num_correct_measurements = num_correct_measurements + 1
    else:
        measurements_list.append(0)

    return num_correct_measurements / len(measurements_list), measurements_list, num_correct_measurements
    

Test code showing how an adversary can reduce the estimated average channel fidelity.

In [None]:
# Push 50 new qubits through source (0.9), adversary (0.7) and destination (1)
# channels into the window built by the baseline estimation cell. The printed
# windowed estimate is expected to drift down as attacked outcomes replace
# baseline ones. The index i selects which stored state/unitary is used.
for i in range(50):
    estimation, measurements_list, num_correct_measurements = enqueue_measurement_result(0.9, 0.7, 1, measurements_list, 
                                                                                         num_correct_measurements, unitaries, 
                                                                                         states, i)
    print(estimation)

A function that simulates how many qubits must be received before the average channel fidelity, estimated over the last $j$ transmitted qubits, drops below the initially estimated lower bound (i.e. before a possible replay is detected).

In [None]:
def _replay_results_file_path(lambda_1, lambda_2, lambda_3, alpha, num_states_used_to_estimate_fidelity):
    """Build the ``np.save`` target in the cwd, stripping dots from the file name only."""
    file_name = ("lambda_param_1_" + str(lambda_1)
                 + "_lambda_param_2_" + str(lambda_2)
                 + "_lambda_param_3_" + str(lambda_3)
                 + "_alpha_" + str(alpha)
                 + "_num_states_used_to_estimate_" + str(num_states_used_to_estimate_fidelity))
    # Sanitise only the file name. Applying .replace(".", "") to the full path
    # would also alter working-directory names that contain a '.'.
    return os.path.join(os.getcwd(), file_name.replace(".", ""))


def get_qubits_needed_to_detect_possible_replay(num_trials, num_states_used_to_estimate_fidelity, lambda_1, lambda_2, lambda_3,
                                                alpha):
    """Simulate how many qubits pass before the windowed fidelity drops to the baseline lower bound.

    Each trial:

    1. Generates fresh random unitaries/states and sends them through the
       honest channels only (``lambda_1`` then ``lambda_3``).
    2. Calls ``estimate_channel_fidelity`` to get the Wilson lower bound and
       the initial window of outcomes.
    3. Pushes qubits that also pass through the adversary's channel
       (``lambda_2``) into the window, one at a time, until the windowed
       estimate is no longer above the lower bound.

    The per-trial counts are summarised with
    ``normal_distribution_confidence_interval``, which prints the interval.
    They are also saved with ``np.save`` in the current working directory,
    under a file name that encodes the parameters (with dots removed).

    Warning: the inner loop has no iteration cap. If the windowed estimate can
    never fall to the lower bound, it does not terminate. For example, when
    every lambda is 1 every outcome is 1 and the estimate stays at 1.

    Returns:
        ``(mean_qubits_needed, confidence_interval_string)``.
    """
    num_qubits_needed_across_all_trials = []
    for _ in range(num_trials):
        unitaries, states = generate_pairs_of_random_unitaries_and_states(num_states_used_to_estimate_fidelity)
        depolarized_states = depolarize_states(lambda_3, depolarize_states(lambda_1, states))
        lower_bound, _, measurements_list, num_correct_measurements = estimate_channel_fidelity(unitaries, depolarized_states,
                                                                                                alpha)
        estimation = 1
        num_qubits_needed = 0
        while estimation > lower_bound:
            # Use a different stored state for each new qubit (the index wraps
            # modulo len(states) inside enqueue_measurement_result).
            estimation, measurements_list, num_correct_measurements = enqueue_measurement_result(lambda_1, lambda_2, lambda_3,
                                                                                                 measurements_list,
                                                                                                 num_correct_measurements,
                                                                                                 unitaries, states,
                                                                                                 num_qubits_needed)
            num_qubits_needed = num_qubits_needed + 1

        num_qubits_needed_across_all_trials.append(num_qubits_needed)

    _, _, _, confidence_interval = normal_distribution_confidence_interval(num_qubits_needed_across_all_trials)

    file_name_of_array = _replay_results_file_path(lambda_1, lambda_2, lambda_3, alpha,
                                                   num_states_used_to_estimate_fidelity)
    np.save(file_name_of_array, np.array(num_qubits_needed_across_all_trials))

    return sum(num_qubits_needed_across_all_trials) / num_trials, confidence_interval


In [None]:
# Single larger run: 500 trials with source 0.9, adversary 0.5 and
# destination 0.9. It uses the global alpha from the baseline cell (0.05),
# unless a later cell such as the parameter sweep has reassigned it.
num_states_used_to_estimate_fidelity = 500
num_trials = 500
get_qubits_needed_to_detect_possible_replay(num_trials, num_states_used_to_estimate_fidelity, 0.9, 0.5, 0.9, alpha)

Run the simulation over a grid of parameters. Results are printed as both plaintext and LaTeX tables.

In [None]:
# Parameter sweep over every combination of:
#   - alpha,
#   - the number of states used for the initial estimate,
#   - the honest-channel lambda (used for both lambda_1 and lambda_3),
#   - the adversary's lambda_2 (1.0 means the adversary leaves the state unchanged).
# Each run's "$mean\pm half-width$" interval string is collected in my_data.
# NOTE: the loop variable `alpha` overwrites the global alpha from earlier cells.
num_trials = 100
alphas_to_try = [0.25,0.35]
num_states_used_to_estimate_to_try = [250, 500, 1000]
lambda_1_and_3_to_try = [0.85,0.9,0.95]
lambda_2_to_try = [0.85,0.9,0.95,1.0]

headers = ["alpha","number of states used to estimate","lambda_1","lambda_2","lambda_3","qubits received before noticing delay"]
my_data = []

for alpha in alphas_to_try:
    for num_states_used_to_estimate in num_states_used_to_estimate_to_try:
        for lambda_1_and_3 in lambda_1_and_3_to_try:
            for lambda_2 in lambda_2_to_try:
                print(f"{alpha=} {num_states_used_to_estimate=} {lambda_1_and_3=} {lambda_2=}")
                _, confidence_interval = get_qubits_needed_to_detect_possible_replay(num_trials, num_states_used_to_estimate,
                                                                                     lambda_1_and_3, lambda_2, lambda_1_and_3,
                                                                                     alpha)
                my_data.append((alpha, num_states_used_to_estimate, lambda_1_and_3, lambda_2, lambda_1_and_3,
                                confidence_interval))
                # Progress report: reprint the table accumulated so far.
                print(tabulate(my_data, headers, tablefmt='grid'))

# 'latex_raw' (not 'latex') so tabulate does not escape the '$' and '\pm' in
# the interval strings; they are already LaTeX math.
print(tabulate(my_data, headers, tablefmt='latex_raw'))
print(tabulate(my_data, headers, tablefmt='grid'))