In [3]:
import pennylane as qml
import numpy as np
import hashlib

##############################################################################
#                           Global Device Setup                               #
##############################################################################

# Single-wire mixed-state simulator shared by every QNode below. A mixed-state
# backend is used so that noise channels (qml.BitFlip, applied in the QNodes)
# can act on the qubit.
# shots=1: each qml.sample call returns exactly one outcome, i.e. each QNode
# execution measures one BB84 qubit once.
# NOTE(review): newer PennyLane releases move shot configuration off the device
# constructor — confirm this still works on the installed version.
dev = qml.device("default.mixed", wires=1, shots=1)

##############################################################################
#                    Helper function: +1/-1 to bit (0/1)                      #
##############################################################################

def bit_from_pauli_result(meas_result):
    """Map a single Pauli eigenvalue to a classical bit.

    An outcome equal to +1 maps to bit 0; any other outcome (i.e. -1)
    maps to bit 1. The result is a plain Python ``int``.
    """
    return int(meas_result != 1)

##############################################################################
#                QNodes: send_and_measure, eve_measure, eve_resend           #
##############################################################################

def _encode_bb84_state(bit, basis):
    """Queue the gates that prepare wire 0 in the BB84 state encoding ``bit``.

    basis 0 (Z): bit 0 -> |0>, bit 1 -> |1>   (PauliX flips |0> to |1>)
    otherwise (X): bit 0 -> |+>, bit 1 -> |->  (Hadamard, then PauliZ for bit 1)

    Must be called from inside a QNode so the operations are recorded on its tape.
    """
    if basis == 0:
        if bit == 1:
            qml.PauliX(wires=0)
    else:
        qml.Hadamard(wires=0)
        if bit == 1:
            qml.PauliZ(wires=0)


def _apply_channel_noise(noise):
    """Queue a bit-flip channel with probability ``noise`` on wire 0 (skipped if noise <= 0).

    Note: X|+> = |+> and X|-> = -|->, so this channel leaves X-basis states
    unchanged; with this noise model, errors only arise on Z-basis encodings.
    """
    if noise > 0.0:
        qml.BitFlip(noise, wires=0)


def _sample_in_basis(basis):
    """Return a sample of PauliZ (basis 0) or PauliX (any other basis) on wire 0."""
    if basis == 0:
        return qml.sample(qml.PauliZ(0))
    return qml.sample(qml.PauliX(0))


@qml.qnode(dev)
def send_and_measure(bit, alice_basis, bob_basis, noise=0.0):
    """
    Direct Alice -> Bob transmission of one BB84 qubit:
      1) Alice encodes 'bit' in 'alice_basis' (Z=0, X=1).
      2) Optional bit-flip noise with probability 'noise'.
      3) Bob measures in 'bob_basis'.
    Returns the single-shot outcome (+1 or -1).
    """
    _encode_bb84_state(bit, alice_basis)
    _apply_channel_noise(noise)
    return _sample_in_basis(bob_basis)


@qml.qnode(dev)
def eve_measure(bit, alice_basis, eve_basis, noise=0.0):
    """
    Eve intercepts the qubit:
      1) Same preparation as Alice ('bit' in 'alice_basis').
      2) Optional bit-flip noise with probability 'noise'.
      3) Measurement in 'eve_basis'.
    Returns Eve's single-shot outcome (+1 or -1).
    """
    _encode_bb84_state(bit, alice_basis)
    _apply_channel_noise(noise)
    return _sample_in_basis(eve_basis)


@qml.qnode(dev)
def eve_resend(eve_outcome, eve_basis, bob_basis, noise=0.0):
    """
    Eve re-encodes a qubit from her outcome (+1 or -1) in 'eve_basis',
    optional bit-flip noise is applied, and Bob measures in 'bob_basis'.
    Returns Bob's single-shot outcome (+1 or -1).
    """
    # Eve's measured eigenvalue becomes the classical bit she re-sends.
    _encode_bb84_state(bit_from_pauli_result(eve_outcome), eve_basis)
    _apply_channel_noise(noise)
    return _sample_in_basis(bob_basis)

##############################################################################
#                         BB84 Protocol Function                              #
##############################################################################

def generate_random_bits_and_bases(n):
    """Draw 'n' random bits, then 'n' random bases, from NumPy's global RNG.

    Bases follow the file's convention: 0 = Z basis, 1 = X basis.
    The bits are drawn first and the bases second.
    """
    bits, bases = (np.random.randint(2, size=n) for _ in range(2))
    return bits, bases

def _transmit_qubits(alice_bits, alice_bases, bob_bases, with_eve, p_eve, noise):
    """Send every qubit to Bob, optionally via an intercept-resend Eve; return Bob's bits."""
    bob_bits = []
    for bit, a_basis, b_basis in zip(alice_bits, alice_bases, bob_bases):
        # Eve only draws randomness when eavesdropping is enabled (short-circuit).
        if with_eve and np.random.rand() < p_eve:
            eve_basis = np.random.randint(2)
            eve_outcome = eve_measure(bit, a_basis, eve_basis, noise=noise)
            eve_val = np.ravel(eve_outcome)[0]
            outcome = eve_resend(eve_val, eve_basis, b_basis, noise=noise)
        else:
            outcome = send_and_measure(bit, a_basis, b_basis, noise=noise)
        # The sample may come back as a 0-d or 1-element array; take the scalar.
        bob_bits.append(bit_from_pauli_result(np.ravel(outcome)[0]))
    return np.array(bob_bits, dtype=int)


def run_bb84_protocol(
    n=1000,
    sample_size=50,
    with_eve=False,
    p_eve=1.0,
    noise=0.0,
    error_threshold=1.0,
    do_info_reconciliation=False,
    do_privacy_amplification=False,
    block_size=16,
    pa_ratio=0.5
):
    """
    Executes BB84 as follows:
      1) Alice and Bob each choose random bits/bases.
      2) Qubits are sent from Alice to Bob. If eavesdropping is active
         with probability p_eve, Eve intercepts and resends.
      3) Only bits where Alice's and Bob's bases match are retained.
      4) A subset of these bits (sample_size) is used to estimate error rate.
         If error_rate > error_threshold, the final key is discarded.
         With sample_size == 0 nothing is sampled and error_rate is NaN
         (the threshold check then never aborts).
      5) Optionally perform information reconciliation to correct
         up to one bit error per block.
      6) Optionally perform privacy amplification to reduce information leakage.

    Returns:
      (final_alice_key, final_bob_key, error_rate, matched_size, final_size)
      On abort (too few matched bits, or error_rate above threshold) both keys
      are empty int arrays and final_size is 0; with too few matched bits the
      error_rate is reported as 1.0.
    """
    # 1) Generate random bits/bases (Bob's random bits are drawn but unused,
    #    which keeps the RNG stream identical to drawing a full bit/basis pair).
    alice_bits, alice_bases = generate_random_bits_and_bases(n)
    _, bob_bases = generate_random_bits_and_bases(n)

    # 2) Transmission
    results = _transmit_qubits(alice_bits, alice_bases, bob_bases, with_eve, p_eve, noise)

    # 3) Keep bits where bases match
    matched_mask = (alice_bases == bob_bases)
    alice_key = alice_bits[matched_mask]
    bob_key = results[matched_mask]
    matched_size = len(alice_key)

    if matched_size < sample_size:
        return np.array([], dtype=int), np.array([], dtype=int), 1.0, matched_size, 0

    # 4) Use a subset to estimate error rate (the sampled bits are publicly
    #    compared, so they are removed from the key).
    sample_indices = np.random.choice(matched_size, size=sample_size, replace=False)
    if sample_size > 0:
        error_rate = np.mean(alice_key[sample_indices] != bob_key[sample_indices])
    else:
        error_rate = np.nan  # nothing sampled: QBER unknown

    remain_indices = np.setdiff1d(np.arange(matched_size), sample_indices)
    final_alice_key = alice_key[remain_indices]
    final_bob_key = bob_key[remain_indices]

    if error_rate > error_threshold:
        return np.array([], dtype=int), np.array([], dtype=int), error_rate, matched_size, 0

    # 5) Information Reconciliation (if enabled)
    if do_info_reconciliation and len(final_alice_key) > 0:
        final_alice_key, final_bob_key = one_way_information_reconciliation(
            final_alice_key, final_bob_key, block_size=block_size
        )

    # 6) Privacy Amplification (if enabled)
    if do_privacy_amplification and len(final_alice_key) > 0:
        final_alice_key, final_bob_key = privacy_amplification(
            final_alice_key, final_bob_key, ratio=pa_ratio
        )

    final_size = len(final_alice_key)
    return final_alice_key, final_bob_key, error_rate, matched_size, final_size

##############################################################################
#              One-Way Information Reconciliation (Parity-Block)             #
##############################################################################

def one_way_information_reconciliation(alice_key, bob_key, block_size=16):
    """
    Performs one-way information reconciliation with parity blocks:
      1) Alice splits her key into blocks of size block_size.
      2) For each block, Alice computes parity and sends it to Bob.
      3) Bob compares with his parity. If different, a binary search is
         performed. During each step, Alice reveals the parity of a sub-block;
         Bob compares and narrows down a single erroneous bit index.
      4) Bob flips that erroneous bit in his key.
      5) All bits belonging to revealed sub-blocks (and each corrected bit)
         are removed from both keys.

    Only an odd number of errors in a block is detectable, and at most one
    error per block is corrected.

    Args:
      alice_key, bob_key: equal-length sequences of 0/1 values (not modified).
      block_size: positive block length.

    Returns:
      (alice_out, bob_out) as int ndarrays with the revealed bits removed.

    Raises:
      ValueError: if the key lengths differ or block_size < 1 (a non-positive
        block size would never advance through the key).
    """
    if len(alice_key) != len(bob_key):
        raise ValueError("Key lengths do not match for reconciliation.")
    if block_size < 1:
        raise ValueError("block_size must be a positive integer.")

    # Private int copies: the caller's arrays are never mutated.
    alice_key = np.array(alice_key, dtype=int)
    bob_key = np.array(bob_key, dtype=int)
    n = len(alice_key)
    revealed = np.zeros(n, dtype=bool)

    for start in range(0, n, block_size):
        end = min(start + block_size, n)
        if np.sum(alice_key[start:end]) % 2 == np.sum(bob_key[start:end]) % 2:
            continue

        # Binary search for one erroneous bit: keep the half whose parities differ.
        left, right = start, end - 1
        while left < right:
            mid = (left + right) // 2
            revealed[left:mid + 1] = True  # Alice publicly reveals this sub-block's parity
            if np.sum(alice_key[left:mid + 1]) % 2 != np.sum(bob_key[left:mid + 1]) % 2:
                right = mid
            else:
                left = mid + 1

        bob_key[left] = 1 - bob_key[left]
        revealed[left] = True

    keep = ~revealed
    return alice_key[keep], bob_key[keep]

##############################################################################
#                      Privacy Amplification (SHA-256)                       #
##############################################################################

def privacy_amplification(alice_key, bob_key, ratio=0.5, salt=b"QKD-demo-salt"):
    """
    Performs privacy amplification on the reconciled key using SHA-256:
      1) Both parties compute the same '0'/'1' bitstring from their identical keys.
      2) The salt is prepended to the encoded bitstring.
      3) They take the SHA-256 hash, convert to binary, and truncate
         to max(1, floor(ratio * original_length)), capped at 256 bits
         (the digest size).

    Note: a fixed-salt cryptographic hash is used here for demonstration;
    it is not a 2-universal hash family.

    Returns:
      (alice_final, bob_final) as independent int arrays (mutating one does
      not affect the other). If the keys differ, both are empty. If the keys
      are empty, the inputs are returned unchanged.

    Raises:
      ValueError: if the key lengths differ.
    """
    if len(alice_key) != len(bob_key):
        raise ValueError("Key lengths do not match before privacy amplification.")

    if not np.array_equal(alice_key, bob_key):
        return np.array([], dtype=int), np.array([], dtype=int)

    n = len(alice_key)
    if n == 0:
        return alice_key, bob_key

    # int(b) normalizes numpy ints/bools/floats so every element encodes as '0'/'1'.
    key_str = "".join(str(int(b)) for b in alice_key)
    digest = hashlib.sha256(salt + key_str.encode("utf-8")).digest()
    digest_bits = "".join(f"{byte:08b}" for byte in digest)

    pa_length = max(int(np.floor(ratio * n)), 1)
    pa_length = min(pa_length, len(digest_bits))

    final_bits = np.array([int(c) for c in digest_bits[:pa_length]], dtype=int)
    # Return separate arrays so the two parties' keys are not aliased.
    return final_bits, final_bits.copy()

##############################################################################
#              Run multiple scenarios and print results                      #
##############################################################################

def _summary_stats(values):
    """Return (mean, std, min, max) of a non-empty array."""
    return np.mean(values), np.std(values), np.min(values), np.max(values)


def run_scenario_multiple_times(
    label,
    n_runs,
    n=1000,
    sample_size=50,
    with_eve=False,
    p_eve=1.0,
    noise=0.0,
    error_threshold=1.0,
    do_info_reconciliation=False,
    do_privacy_amplification=False,
    block_size=16,
    pa_ratio=0.5
):
    """
    Runs the BB84 protocol 'n_runs' times for a particular scenario
    and prints per-run and aggregated results (error rates, matched sizes,
    final sizes). 'block_size' and 'pa_ratio' are forwarded to
    run_bb84_protocol.

    Returns:
      (errors, matched_sz, final_sz) as NumPy arrays with one entry per run.

    Raises:
      ValueError: if n_runs < 1 (summary statistics need at least one run).
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1 to compute summary statistics.")

    print(f"===== CASE {label} =====")
    print(f"Parameters: n={n}, sample_size={sample_size}, with_eve={with_eve}, "
          f"p_eve={p_eve}, noise={noise}, Reconciliation={do_info_reconciliation}, "
          f"PA={do_privacy_amplification}")

    errors, matched_sz, final_sz = [], [], []

    for i in range(n_runs):
        _, _, err, match_size, fin_size = run_bb84_protocol(
            n=n,
            sample_size=sample_size,
            with_eve=with_eve,
            p_eve=p_eve,
            noise=noise,
            error_threshold=error_threshold,
            do_info_reconciliation=do_info_reconciliation,
            do_privacy_amplification=do_privacy_amplification,
            block_size=block_size,
            pa_ratio=pa_ratio
        )
        errors.append(err)
        matched_sz.append(match_size)
        final_sz.append(fin_size)
        print(f"  Run {i+1}/{n_runs} => Error={err*100:.2f}%, Matched={match_size}, Final={fin_size}")

    errors = np.array(errors)
    matched_sz = np.array(matched_sz)
    final_sz = np.array(final_sz)

    avg_err, std_err, min_err, max_err = _summary_stats(errors)
    avg_matched, std_matched, min_matched, max_matched = _summary_stats(matched_sz)
    avg_final, std_final, min_final, max_final = _summary_stats(final_sz)

    print(f"\nSummary for {label}:")
    print(f"  Error Rate:  avg={avg_err*100:.2f}% ± {std_err*100:.2f}% "
          f"[min={min_err*100:.2f}%, max={max_err*100:.2f}%]")
    print(f"  Matched Sz:  avg={avg_matched:.1f} ± {std_matched:.1f} "
          f"[min={min_matched}, max={max_matched}]")
    print(f"  Final Sz:    avg={avg_final:.1f} ± {std_final:.1f} "
          f"[min={min_final}, max={max_final}]")
    print("="*60, "\n")

    return errors, matched_sz, final_sz

##############################################################################
#                               Main Execution                               #
##############################################################################

if __name__ == "__main__":
    # Seed NumPy's global RNG so the printed runs are reproducible.
    np.random.seed(42)

    # Shared experiment settings.
    n_runs = 5                # repetitions per scenario
    n_qubits = 1000           # qubits Alice sends per protocol run
    sample_check = 50         # matched bits sacrificed to estimate the error rate
    noise_probability = 0.02  # bit-flip probability for the noisy scenarios
    threshold = 0.2           # the key is discarded above this sampled error rate

    # (label, with_eve, p_eve, noise, reconcile_and_amplify), executed in order.
    scenarios = (
        ("A: No Eavesdrop, No Noise", False, 0.0, 0.0, False),
        ("B: No Eavesdrop, No Noise +Reconciliation +PA", False, 0.0, 0.0, True),
        ("C: Eavesdrop + Noise +Reconciliation +PA", True, 1.0, noise_probability, True),
        ("D: Noise Only +Reconciliation +PA", False, 0.0, noise_probability, True),
    )

    for case_label, eve_on, eve_prob, noise_level, post_process in scenarios:
        run_scenario_multiple_times(
            label=case_label,
            n_runs=n_runs,
            n=n_qubits,
            sample_size=sample_check,
            with_eve=eve_on,
            p_eve=eve_prob,
            noise=noise_level,
            error_threshold=threshold,
            do_info_reconciliation=post_process,
            do_privacy_amplification=post_process,
        )


===== CASE A: No Eavesdrop, No Noise =====
Parameters: n=1000, sample_size=50, with_eve=False, p_eve=0.0, noise=0.0, Reconciliation=False, PA=False
  Run 1/5 => Error=0.00%, Matched=500, Final=450
  Run 2/5 => Error=0.00%, Matched=499, Final=449
  Run 3/5 => Error=0.00%, Matched=489, Final=439
  Run 4/5 => Error=0.00%, Matched=511, Final=461
  Run 5/5 => Error=0.00%, Matched=509, Final=459

Summary for A: No Eavesdrop, No Noise:
  Error Rate:  avg=0.00% ± 0.00% [min=0.00%, max=0.00%]
  Matched Sz:  avg=501.6 ± 7.9 [min=489, max=511]
  Final Sz:    avg=451.6 ± 7.9 [min=439, max=461]

===== CASE B: No Eavesdrop, No Noise +Reconciliation +PA =====
Parameters: n=1000, sample_size=50, with_eve=False, p_eve=0.0, noise=0.0, Reconciliation=True, PA=True
  Run 1/5 => Error=0.00%, Matched=491, Final=220
  Run 2/5 => Error=0.00%, Matched=514, Final=232
  Run 3/5 => Error=0.00%, Matched=488, Final=219
  Run 4/5 => Error=0.00%, Matched=466, Final=208
  Run 5/5 => Error=0.00%, Matched=498, Final=224