In [1]:
# %% [markdown]
# Import necessary libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Subset
import torch.optim as optim
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_absolute_error
import json
from tqdm import tqdm
import torch.nn.functional as F
import joblib
import os
import sys
from torch.utils.data import DataLoader, random_split
import time  # For timing

# Set matplotlib to use a non-interactive backend
# (Agg only renders to in-memory image buffers, so figures have to be written
# out with savefig to be inspected.)
plt.switch_backend('Agg')

# Resolve the project root as the parent of the current working directory,
# which is taken to be the notebook's directory.
notebook_dir = os.getcwd()
project_root = os.path.dirname(notebook_dir)

# Make the project root importable. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

# %% [markdown]
# Checking if the MPS (Metal Performance Shaders) backend is available

# %%
# Select the compute device.
# MPS (Apple Silicon GPU) is still preferred, so behaviour on the original
# M2 Pro setup is unchanged. Instead of aborting when MPS is missing, fall
# back to CUDA and then to the CPU so the notebook runs on other machines too.
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Use GPU on M2 Pro
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [2]:
# %%
# Import the objective function
from QKD_Functions.QKD_Functions import (
    calculate_factorial, calculate_tau_n, calculate_eta_ch, calculate_eta_sys,
    calculate_D_mu_k, calculate_n_X_total, calculate_N, calculate_n_Z_total,
    calculate_e_mu_k, calculate_e_obs, calculate_h, calculate_lambda_EC,
    calculate_sqrt_term, calculate_n_pm, calculate_S_0, calculate_S_1,
    calculate_m_mu_k, calculate_m_pm, calculate_v_1, calculate_gamma,
    calculate_Phi, calculate_LastTwoTerm, calculate_l, calculate_R,
    experimental_parameters, other_parameters, calculate_key_rates_and_metrics,
    penalty, objective,
)

  from tqdm.autonotebook import tqdm


In [3]:
# Define a safe wrapper for the objective function
# def safe_objective(params, L, nx, **kwargs):
#     try:
#         key_rate = objective(params, L, nx, **kwargs)[0]
#         if np.isnan(key_rate) or np.isinf(key_rate) or key_rate <= 0:
#             print(f"Invalid key rate (NaN, Inf, or <= 0) for params {params}, L={L}, nx={nx}: {key_rate}")
#             return None  # Return None to indicate invalid result
#         return key_rate
#     except Exception as e:
#         print(f"Error in objective for params {params}, L={L}, nx={nx}: {e}")
#         return None  # Return None on error

def safe_objective(params, L, nx, **kwargs):
    """Evaluate ``objective`` defensively and return a usable key rate or ``None``.

    ``params`` is first screened for NaN/Inf entries. ``objective`` is then
    called with ``params``, ``L``, ``nx`` and any extra keyword arguments, and
    the first element of its result is treated as the key rate.

    Returns the key rate when it is finite and strictly positive. In every
    other case (non-finite inputs, NaN / Inf / non-positive key rate, or any
    exception raised on the way) a diagnostic line is printed and ``None`` is
    returned.
    """
    try:
        # The network can emit NaN/Inf, so check the inputs before the objective.
        candidate = np.array(params)
        if np.any(np.isnan(candidate)) or np.any(np.isinf(candidate)):
            print(f"Invalid input params (NaN/Inf): {candidate}, L={L}, nx={nx}")
            return None

        result = objective(params, L, nx, **kwargs)
        key_rate = result[0]

        # Classify the result; the first matching problem wins.
        if np.isnan(key_rate):
            problem = f"Objective returned NaN for params {params}, L={L}, nx={nx}"
        elif np.isinf(key_rate):
            problem = f"Objective returned Inf for params {params}, L={L}, nx={nx}"
        elif key_rate <= 0:
            # Non-positive key rates are reported and filtered out.
            problem = f"Objective returned non-positive key rate ({key_rate}) for params {params}, L={L}, nx={nx}"
        else:
            return key_rate

        print(problem)
        return None
    except Exception as e:
        print(f"Exception in objective for params {params}, L={L}, nx={nx}: {e}")
        return None
    
# %%
# Load the combined training dataset (a JSON object keyed by n_X)
dataset_path = '../Training_Data/n_X/good/cleaned_combined_datasets.json'
with open(dataset_path, 'r') as f:
    data_by_nx = json.load(f)

print(f"The overall dataset contains {len(data_by_nx)} entries (number of unique n_X values).")

# Report how many entries each n_X holds before any filtering is applied
for n_x in data_by_nx:
    print(f"Number of entries for n_X = {n_x}: {len(data_by_nx[n_x])}")

The overall dataset contains 6 entries (number of unique n_X values).
Number of entries for n_X = 10000.0: 736
Number of entries for n_X = 100000.0: 855
Number of entries for n_X = 1000000.0: 902
Number of entries for n_X = 10000000.0: 927
Number of entries for n_X = 100000000.0: 942
Number of entries for n_X = 1000000000.0: 948


In [4]:
# Flatten the per-n_X structure into one list, keeping only usable entries
def _keep_entry(item):
    # An entry is kept when its key rate is positive and e_1 * 100 is at most 200.
    return item["key_rate"] > 0 and item["e_1"] * 100 <= 200

cleaned_data = []
for n_x, entries in data_by_nx.items():
    filtered_entries = list(filter(_keep_entry, entries))
    print(f"After filtering, n_X = {n_x} has {len(filtered_entries)} entries.")
    cleaned_data += filtered_entries

# Verify the cleaned dataset
if cleaned_data:
    print(f"Filtered dataset contains {len(cleaned_data)} entries.")
    print("\nSample entry from the cleaned dataset:")
    print(json.dumps(cleaned_data[0], indent=2))
    print("\nNumber of unique n_X values:", len(data_by_nx))
else:
    print("No valid data after filtering.")

# %%
# Inputs: e_1..e_4. Targets: the five optimized parameters, in this column order.
feature_keys = ('e_1', 'e_2', 'e_3', 'e_4')
target_keys = ('mu_1', 'mu_2', 'P_mu_1', 'P_mu_2', 'P_X_value')
X = np.array([[item[key] for key in feature_keys] for item in cleaned_data])
Y = np.array([[item['optimized_params'][key] for key in target_keys] for item in cleaned_data])

After filtering, n_X = 10000.0 has 736 entries.
After filtering, n_X = 100000.0 has 855 entries.
After filtering, n_X = 1000000.0 has 902 entries.
After filtering, n_X = 10000000.0 has 927 entries.
After filtering, n_X = 100000000.0 has 942 entries.
After filtering, n_X = 1000000000.0 has 948 entries.
Filtered dataset contains 5310 entries.

Sample entry from the cleaned dataset:
{
  "fiber_length": 0.0,
  "e_1": 0.0,
  "e_2": 6.221848749616356,
  "e_3": 0.5,
  "e_4": 4.0,
  "key_rate": 0.00037995895580899045,
  "optimized_params": {
    "mu_1": 0.6126301275287198,
    "mu_2": 0.13818683663411718,
    "P_mu_1": 0.04604299826386499,
    "P_mu_2": 0.6110117432731096,
    "P_X_value": 0.414731486913827
  }
}

Number of unique n_X values: 6


In [None]:
from sklearn.utils import shuffle

# One seed for both the shuffle and the train/validation split, so that the
# whole data preparation is reproducible from a fresh kernel.
split_seed = 42

# NOTE: X and Y are overwritten in place below. Re-running this cell without
# re-running the cell that builds X and Y would scale already-scaled data.
X, Y = shuffle(X, Y, random_state=split_seed)

# Standardize the inputs. The scaler is fitted on the full dataset, i.e.
# before the train/validation split further down.
scaler = StandardScaler()
X = scaler.fit_transform(X)

y_scaler = MinMaxScaler()  # Scale targets to [0, 1]
Y = y_scaler.fit_transform(Y)

# Save the scalers (create the target directory first so a fresh checkout works)
os.makedirs('../models', exist_ok=True)
joblib.dump(scaler, '../models/scaler.pkl')  # Save StandardScaler
joblib.dump(y_scaler, '../models/y_scaler.pkl')  # Save MinMaxScaler

print(f"X shape: {X.shape}, Y shape: {Y.shape}")
dataset = TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.float32))

# %%
# Split dataset into train and validation sets
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation

# A dedicated, seeded generator makes the split identical on every run.
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

# Data loaders with increased batch size
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

X shape: (5310, 4), Y shape: (5310, 5)


In [6]:
# %%
# Load evaluation data for n_X = 5e8
evaluation_path = "../n_X/good/5e8_100_reordered_qkd_grouped_dataset_20250226_191349.json"
with open(evaluation_path, 'r') as f:
    evaluation_dataset = json.load(f)

# Select target n_X for evaluation; the JSON keys are float-formatted strings
target_nx = 5 * 100000000  # 5e8
nx_key = str(float(target_nx))
if nx_key in evaluation_dataset:
    evaluation_data = evaluation_dataset[nx_key]
else:
    raise ValueError(f"No data found for n_X = {target_nx}")

# Verify the size of the evaluation dataset
print(f"Number of data points for n_X = {target_nx}: {len(evaluation_data)}")

Number of data points for n_X = 500000000: 100


In [7]:
# Extract evaluation fiber lengths and optimized parameters
fiber_lengths_un_seen = np.array([entry["fiber_length"] for entry in evaluation_data])
optimized_params_array_un_seen = np.array(
    [list(entry["optimized_params"].values()) for entry in evaluation_data]
)

# Keyword arguments handed to the objective for every evaluation point
unseen_objective_kwargs = dict(
    alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7,
    epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16,
    e_mis=5e-3, P_ap=0, n_event=1,
)

# Evaluate the optimized key rate at every point, remembering which ones are usable
optimized_key_rates_un_seen = []
valid_indices_un_seen = []
for idx, (params, L) in enumerate(zip(optimized_params_array_un_seen, fiber_lengths_un_seen)):
    key_rate = safe_objective(params, L, target_nx, **unseen_objective_kwargs)
    if key_rate is not None:
        optimized_key_rates_un_seen.append(key_rate)
        valid_indices_un_seen.append(idx)
    else:
        print(f"Skipping invalid key rate at index {idx}")
optimized_key_rates_un_seen = np.array(optimized_key_rates_un_seen)

# Keep only the rows whose key rate was valid, so all three arrays stay aligned
fiber_lengths_un_seen = fiber_lengths_un_seen[valid_indices_un_seen]
optimized_params_array_un_seen = optimized_params_array_un_seen[valid_indices_un_seen]

Objective returned non-positive key rate (-2.5459075744215506e-31) for params [5.21125578e-01 3.37384994e-01 1.92634353e-01 6.63224253e-01
 1.00000000e-12], L=189.89898989898992, nx=500000000
Skipping invalid key rate at index 94
Objective returned non-positive key rate (-1.0847670075212521e-30) for params [9.00000000e-01 4.09777073e-01 1.21908112e-01 6.52515070e-01
 1.00000000e-12], L=191.91919191919195, nx=500000000
Skipping invalid key rate at index 95
Objective returned non-positive key rate (-1.0756742243147874e-30) for params [8.13544868e-01 2.57789157e-01 3.75126595e-01 1.07065241e-01
 1.00000000e-12], L=193.93939393939394, nx=500000000
Skipping invalid key rate at index 96
Objective returned non-positive key rate (-5.9934352286340745e-31) for params [5.53278979e-01 2.55719551e-01 4.01755354e-01 4.61925236e-01
 1.00000000e-12], L=195.95959595959596, nx=500000000
Skipping invalid key rate at index 97
Objective returned non-positive key rate (-8.89190788063512e-31) for params [7.5

In [8]:
# Prepare test inputs for evaluation (n_X = 5e8)
# e_2, e_3 and e_4 are the same for every fiber length, so compute them once.
e_2 = -np.log10(6e-7)
e_3 = 5e-3 * 100
e_4 = np.log10(target_nx)

unseen_rows = []
for L in fiber_lengths_un_seen:
    e_1 = L / 100
    unseen_rows.append([e_1, e_2, e_3, e_4])
X_test_un_seen = np.array(unseen_rows)

# Apply the input scaler fitted earlier, then move the result to the device
X_test_scaled_un_seen = scaler.transform(X_test_un_seen)
X_test_tensor_un_seen = torch.tensor(X_test_scaled_un_seen, dtype=torch.float32).to(device)
print(f"After filtering NaN in optimized_key_rates_un_seen, {len(optimized_key_rates_un_seen)} data points remain for n_X = 5e8.")

After filtering NaN in optimized_key_rates_un_seen, 94 data points remain for n_X = 5e8.


In [9]:
# %% [markdown]
# ### Prepare Evaluation Data for All n_X Values (10^4 to 10^9) Before Training

# Load the combined dataset
combined_file_path = '../Training_Data/n_X/good/cleaned_combined_datasets.json'
try:
    f = open(combined_file_path, 'r')
except FileNotFoundError:
    # Re-raise with a message that names the path that was tried
    raise FileNotFoundError(f"Combined dataset not found at {combined_file_path}.")
with f:
    combined_data = json.load(f)
print(f"Loaded combined dataset with {len(combined_data)} n_X entries.")

Loaded combined dataset with 6 n_X entries.


In [10]:
# n_X values to evaluate: 10^4, 10^5, ..., 10^9
nx_values = [10 ** exponent for exponent in range(4, 10)]

# Evaluation data per n_X, filled in by the preparation loop
all_evaluation_data = dict()

In [11]:
# Extract and prepare evaluation data for each n_X

# The values 6e-7 and 5e-3 are needed twice: as objective keyword arguments and
# to build the network inputs e_2 and e_3. Defining them once keeps the two
# uses from drifting apart.
eval_P_dc = 6e-7
eval_e_mis = 5e-3
eval_objective_kwargs = dict(
    alpha=0.2, eta_Bob=0.1, P_dc_value=eval_P_dc,
    epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16,
    e_mis=eval_e_mis, P_ap=0, n_event=1,
)
# Explicit parameter order, identical to the column order used to build Y, so
# the result does not depend on the key order inside the JSON file.
eval_param_keys = ('mu_1', 'mu_2', 'P_mu_1', 'P_mu_2', 'P_X_value')

# NOTE(review): this loop rebinds `evaluation_data` and `nx_key`, which the
# n_X = 5e8 cell also assigns; after it runs they hold the last iteration's values.
for nx in nx_values:
    nx_key = str(float(nx))  # JSON keys are float-formatted strings, e.g. "10000.0"
    if nx_key not in combined_data:
        print(f"No data found for n_X = {nx} in the combined dataset. Skipping...")
        continue
    evaluation_data = combined_data[nx_key]
    print(f"Extracted {len(evaluation_data)} entries for n_X = {nx}.")

    # Extract fiber lengths and optimized parameters
    fiber_lengths = np.array([entry["fiber_length"] for entry in evaluation_data])
    optimized_params_array = np.array(
        [[entry["optimized_params"][key] for key in eval_param_keys] for entry in evaluation_data]
    )

    # Compute optimized key rates, remembering which rows are usable
    optimized_key_rates = []
    valid_indices = []
    for idx, (params, L) in enumerate(zip(optimized_params_array, fiber_lengths)):
        key_rate = safe_objective(params, L, nx, **eval_objective_kwargs)
        if key_rate is None:
            print(f"Skipping invalid optimized key rate at index {idx} for n_X = {nx}")
            continue
        optimized_key_rates.append(key_rate)
        valid_indices.append(idx)
    optimized_key_rates = np.array(optimized_key_rates)

    # With no valid point there is nothing to scale or plot; scaler.transform
    # would fail on an empty input, so skip this n_X instead of crashing.
    if not valid_indices:
        print(f"No valid key rates for n_X = {nx}. Skipping...")
        continue

    # Filter the evaluation data to exclude invalid key rates
    fiber_lengths = fiber_lengths[valid_indices]
    optimized_params_array = optimized_params_array[valid_indices]

    # Prepare test inputs for evaluation: one row [e_1, e_2, e_3, e_4] per fiber
    # length. Only e_1 varies within a given n_X.
    n_points = len(fiber_lengths)
    X_test = np.column_stack([
        fiber_lengths / 100,
        np.full(n_points, -np.log10(eval_P_dc)),
        np.full(n_points, eval_e_mis * 100),
        np.full(n_points, np.log10(nx)),
    ])
    X_test_scaled = scaler.transform(X_test)
    X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32).to(device)

    # Store all evaluation data in a dictionary
    all_evaluation_data[nx] = {
        'fiber_lengths': fiber_lengths,
        'optimized_params_array': optimized_params_array,
        'optimized_key_rates': optimized_key_rates,
        'X_test_tensor': X_test_tensor
    }
    print(f"After filtering invalid key rates, {len(optimized_key_rates)} data points remain for n_X = {nx}.")

Extracted 736 entries for n_X = 10000.
After filtering invalid key rates, 736 data points remain for n_X = 10000.
Extracted 855 entries for n_X = 100000.
After filtering invalid key rates, 855 data points remain for n_X = 100000.
Extracted 902 entries for n_X = 1000000.
After filtering invalid key rates, 902 data points remain for n_X = 1000000.
Extracted 927 entries for n_X = 10000000.
After filtering invalid key rates, 927 data points remain for n_X = 10000000.
Extracted 942 entries for n_X = 100000000.
After filtering invalid key rates, 942 data points remain for n_X = 100000000.
Extracted 948 entries for n_X = 1000000000.
After filtering invalid key rates, 948 data points remain for n_X = 1000000000.


In [12]:
# %%
# Define the neural network model
class BB84NN(nn.Module):
    """Fully connected network mapping 4 input features to 5 outputs.

    Three hidden layers (16, 32 and 16 units) are each followed by a ReLU;
    the final layer is linear with no activation.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=4, out_features=16)
        self.fc2 = nn.Linear(in_features=16, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=16)
        self.fc4 = nn.Linear(in_features=16, out_features=5)

    def forward(self, x):
        """Run ``x`` through the three ReLU hidden layers and the linear output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

In [13]:
# Build the model on the selected device together with its loss and optimizer
model = BB84NN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Cut the learning rate by 10x after 5 scheduler steps without improvement, down to 1e-6
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    min_lr=1e-6,
)

# Per-epoch history, appended to by the training loop
train_losses, val_losses, learning_rates = [], [], []

In [14]:
def plot_keyrate_subplots(all_data, epoch, filename):
    """
    Plot log10 of the key rate vs fiber length, one subplot per n_X value (at most six).

    Parameters
    ----------
    all_data : dict
        Keys are n_X values. Each value is a dict with 'fiber_lengths' and
        'optimized_key_rates' and, optionally, 'predicted_key_rates'. The
        entries are indexed with boolean masks, so NumPy arrays are expected.
    epoch : int
        Zero-based epoch index; the figure title shows ``epoch + 1``.
    filename : str
        Path the figure is saved to.

    Notes
    -----
    For every n_X the curves are truncated at the first point where the
    optimized key rate is 1e-8 or lower. Non-positive predicted key rates have
    no logarithm and are drawn as gaps.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    fig.suptitle(f"Key Rates vs Fiber Length at Epoch {epoch+1} ($n_X = 10^s, s=4,5,...,9$)", fontsize=16)

    axes = axes.flatten()
    nx_values = sorted(all_data.keys())  # Ensure n_X values are sorted

    for i, nx in enumerate(nx_values):
        if i >= 6:  # Limit to 6 subplots
            break

        ax = axes[i]
        fiber_lengths = all_data[nx]['fiber_lengths']
        optimized_key_rates = all_data[nx]['optimized_key_rates']
        predicted_key_rates = all_data[nx].get('predicted_key_rates', None)

        # Filter out NaN values (in either series, so the arrays stay aligned)
        valid_mask = ~(np.isnan(optimized_key_rates))
        if predicted_key_rates is not None:
            valid_mask &= ~np.isnan(predicted_key_rates)
        fiber_lengths = fiber_lengths[valid_mask]
        optimized_key_rates = optimized_key_rates[valid_mask]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[valid_mask]

        # Find cutoff where optimized key rate becomes very small
        threshold = 1e-8
        cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
        if len(cutoff_idx) > 0:
            cutoff_idx = cutoff_idx[0]
        else:
            cutoff_idx = len(fiber_lengths)

        fiber_lengths = fiber_lengths[:cutoff_idx]
        optimized_key_rates = optimized_key_rates[:cutoff_idx]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[:cutoff_idx]

        # Optimized curve: solid, fully opaque. Colour and line style are given
        # as keywords only; combining a 'b-' format string with linestyle= makes
        # matplotlib warn about a redundant definition.
        ax.plot(fiber_lengths, np.log10(optimized_key_rates), color='b',
                label='Optimized Key Rate', linestyle='-', linewidth=2.0, alpha=1.0)
        if predicted_key_rates is not None and len(predicted_key_rates) > 0:
            # Predictions are not guaranteed to be positive. Mask the ones that
            # are not, so log10 yields NaN (a gap) instead of warnings and +/-inf.
            positive_predictions = np.where(predicted_key_rates > 0, predicted_key_rates, np.nan)
            # Predicted curve: thick dotted line, slightly transparent
            ax.plot(fiber_lengths, np.log10(positive_predictions), color='r',
                    label='Predicted Key Rate', linestyle=':',
                    linewidth=4.0, alpha=0.7)

        # Compute the exponent for nx (e.g., nx = 10000 -> exponent = 4)
        exponent = int(np.log10(nx))
        # Set the title with LaTeX formatting for 10^exponent
        ax.set_title(f"$n_X = 10^{{{exponent}}}$")
        ax.grid(True)

        if i >= 3:
            ax.set_xlabel("Fiber Length (km)")
        if i % 3 == 0:
            # The plotted quantity is log10 of the key rate, so say so on the axis.
            ax.set_ylabel(r"$\log_{10}$(Secret Key Rate per Pulse)")

        ax.legend(loc='upper right')

    # Remove empty subplots
    for i in range(len(nx_values), 6):
        fig.delaxes(axes[i])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)

    print(f"Key rate subplots saved to {filename}")

In [15]:
def plot_parameters_subplots(all_data, epoch, filename):
    """
    Plot optimized and predicted parameters vs fiber length, one subplot per n_X value (at most six).

    Parameters
    ----------
    all_data : dict
        Keys are n_X values. Each value is a dict with 'fiber_lengths',
        'optimized_params_array', 'optimized_key_rates' and, optionally,
        'predicted_params_array'. The parameter arrays are 2-D with one column
        per parameter; the columns are expected in the order
        mu_1, mu_2, P_mu_1, P_mu_2, P_X, which is how the arrays are built
        elsewhere in this notebook.
    epoch : int
        Zero-based epoch index; the figure title shows ``epoch + 1``.
    filename : str
        Path the figure is saved to.

    Notes
    -----
    Rows are truncated at the first point where the optimized key rate is 1e-8
    or lower, mirroring the key-rate plots. One shared legend, built from the
    first subplot, is placed to the right of the grid.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    fig.suptitle(f"Optimized and Predicted Parameters vs Fiber Length at Epoch {epoch+1} ($n_X = 10^s, s=4,5,...,9$)", fontsize=16)

    axes = axes.flatten()
    nx_values = sorted(all_data.keys())

    # Labels follow the column order of the parameter arrays
    # (mu_1, mu_2, P_mu_1, P_mu_2, P_X). Raw strings keep the LaTeX backslashes
    # from being read as (invalid) Python escape sequences.
    param_labels = [r'$\mu_1$', r'$\mu_2$', r'$P_{\mu_1}$', r'$P_{\mu_2}$', r'$P_X$']
    colors = ['blue', 'orange', 'green', 'red', 'purple']

    # Store handles and labels for the shared legend
    handles, labels = [], []

    for i, nx in enumerate(nx_values):
        if i >= 6:
            break

        ax = axes[i]
        fiber_lengths = all_data[nx]['fiber_lengths']
        optimized_params_array = all_data[nx]['optimized_params_array']
        predicted_params_array = all_data[nx].get('predicted_params_array', None)

        # Filter out NaN values in key rates to align with key rate plots
        optimized_key_rates = all_data[nx]['optimized_key_rates']
        valid_mask = ~(np.isnan(optimized_key_rates))
        fiber_lengths = fiber_lengths[valid_mask]
        optimized_params_array = optimized_params_array[valid_mask]
        if predicted_params_array is not None:
            predicted_params_array = predicted_params_array[valid_mask]

        # Truncate based on key rate cutoff
        threshold = 1e-8
        cutoff_idx = np.where(optimized_key_rates[valid_mask] <= threshold)[0]
        if len(cutoff_idx) > 0:
            cutoff_idx = cutoff_idx[0]
        else:
            cutoff_idx = len(fiber_lengths)

        fiber_lengths = fiber_lengths[:cutoff_idx]
        optimized_params_array = optimized_params_array[:cutoff_idx]
        if predicted_params_array is not None:
            predicted_params_array = predicted_params_array[:cutoff_idx]

        # One colour per parameter: solid line for optimized, dotted for predicted
        for param_idx, (label, color) in enumerate(zip(param_labels, colors)):
            # Plot optimized line
            line_opt = ax.plot(fiber_lengths, optimized_params_array[:, param_idx],
                               label=f'Optimized {label}', color=color, linestyle='-', linewidth=2.0, alpha=1.0)
            # Plot predicted line
            if predicted_params_array is not None:
                line_pred = ax.plot(fiber_lengths, predicted_params_array[:, param_idx],
                                    label=f'Predicted {label}', color=color, linestyle=':',
                                    linewidth=4.0, alpha=0.7)

            # Collect handles and labels only from the first subplot (i=0)
            if i == 0:
                handles.append(line_opt[0])
                labels.append(f'Optimized {label}')
                if predicted_params_array is not None:
                    handles.append(line_pred[0])
                    labels.append(f'Predicted {label}')

        # Compute the exponent for nx (e.g., nx = 10000 -> exponent = 4)
        exponent = int(np.log10(nx))
        # Set the title with LaTeX formatting for 10^exponent
        ax.set_title(f"$n_X = 10^{{{exponent}}}$")

        ax.set_ylim(0.0, 1.0)
        ax.grid(True)

        if i >= 3:
            ax.set_xlabel("Fiber Length (km)")
        if i % 3 == 0:
            ax.set_ylabel("Parameter Value")

    # Remove empty subplots
    for i in range(len(nx_values), 6):
        fig.delaxes(axes[i])

    # Add a single shared legend outside the entire grid
    fig.legend(handles, labels, loc='center left', bbox_to_anchor=(0.92, 0.5))

    # Adjust layout to make space for the legend on the right
    plt.tight_layout(rect=[0, 0, 0.9, 0.95])  # Leave 10% space on the right for the shared legend
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    # Release the figure; otherwise every call leaves one more open figure behind.
    plt.close(fig)
    print(f"Parameter subplots saved to {filename}")

In [16]:
def plot_relative_errors(fiber_lengths, optimized_key_rates, optimized_params_array,
                         predicted_key_rates, predicted_params_array, epoch, filename,
                         nx=None, threshold=1e-7):
    """
    Plot the relative error of each predicted parameter and of the key rate vs fiber length.

    Parameters
    ----------
    fiber_lengths, optimized_key_rates, predicted_key_rates : array-like
        1-D series of equal length; indexed with boolean masks, so NumPy
        arrays are expected.
    optimized_params_array, predicted_params_array : array-like
        2-D arrays with one row per fiber length and five columns in the
        order mu_1, mu_2, P_mu_1, P_mu_2, P_X.
    epoch : int
        Accepted for signature compatibility; not used by the plot.
    filename : str
        Path the figure is saved to.
    nx : number, optional
        n_X value shown in the title as a power of ten. When omitted, the
        title carries no n_X.
    threshold : float
        The series are truncated at the first point where the optimized key
        rate is at or below this value.

    Notes
    -----
    Relative error is ``(predicted - optimized) / max(optimized, floor)`` with
    a floor of 1e-10 for the parameters and 1e-7 for the key rate. Rows where
    either key rate is NaN are dropped first. The elapsed plotting time is
    printed at the end.
    """
    # Start timing
    plot_start_time = time.time()

    valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
    fiber_lengths = fiber_lengths[valid_mask]
    optimized_key_rates = optimized_key_rates[valid_mask]
    predicted_key_rates = predicted_key_rates[valid_mask]
    optimized_params_array = optimized_params_array[valid_mask]
    predicted_params_array = predicted_params_array[valid_mask]

    # Find cutoff where optimized key rate becomes very small
    cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
    if len(cutoff_idx) > 0:
        cutoff_idx = cutoff_idx[0]
    else:
        cutoff_idx = len(fiber_lengths)

    fiber_lengths = fiber_lengths[:cutoff_idx]
    optimized_key_rates = optimized_key_rates[:cutoff_idx]
    predicted_key_rates = predicted_key_rates[:cutoff_idx]
    optimized_params_array = optimized_params_array[:cutoff_idx]
    predicted_params_array = predicted_params_array[:cutoff_idx]

    # Compute relative errors
    # Parameters: mu_1, mu_2, P_mu_1, P_mu_2, P_X
    relative_errors = []
    # Raw strings keep the LaTeX backslashes from being read as (invalid)
    # Python escape sequences.
    param_labels = [r'$\mu_1$', r'$\mu_2$', r'$P_{\mu_1}$', r'$P_{\mu_2}$', r'$P_X$']
    for i in range(5):
        optimized = optimized_params_array[:, i]
        predicted = predicted_params_array[:, i]
        # Floor the denominator so a zero optimum cannot cause a division by zero
        denominator = np.maximum(optimized, 1e-10)
        rel_error = (predicted - optimized) / denominator
        relative_errors.append(rel_error)

    # Key rate relative error, with a fixed 1e-7 floor on the denominator
    # NOTE(review): this floor is independent of `threshold`; confirm that is intended.
    denominator = np.maximum(optimized_key_rates, 1e-7)
    key_rate_rel_error = (predicted_key_rates - optimized_key_rates) / denominator

    # Create figure with 6 subplots (2 rows, 3 columns)
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Plot relative errors: the five parameters fill the grid row by row
    for i in range(5):
        row = i // 3
        col = i % 3
        ax = fig.add_subplot(gs[row, col])
        ax.plot(fiber_lengths, relative_errors[i], 'b-', label=f'Relative Error {param_labels[i]}')
        ax.set_xlabel('Fiber Length (km)')
        ax.set_ylabel('Relative Error')
        ax.set_title(f'Relative Error for {param_labels[i]}')
        ax.grid(True)
        ax.legend(loc='best')

    # Key rate subplot takes the remaining (bottom-right) cell
    ax_key = fig.add_subplot(gs[1, 2])
    ax_key.plot(fiber_lengths, key_rate_rel_error, 'r-', label='Relative Error Key Rate')
    ax_key.set_xlabel('Fiber Length (km)')
    ax_key.set_ylabel('Relative Error')
    ax_key.set_title('Relative Error for Key Rate')
    ax_key.grid(True)
    ax_key.legend(loc='best')

    # Overall title. Without nx there is no exponent to show, so leave the
    # n_X part out rather than rendering an empty "10^{}".
    if nx is not None:
        exponent = int(np.log10(nx))
        fig.suptitle(f'Relative Errors of Parameters for $n_X = 10^{{{exponent}}}$', fontsize=16)
    else:
        fig.suptitle('Relative Errors of Parameters', fontsize=16)

    # Save and close
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Relative error plot saved to {filename}")

    # Print plotting time
    plot_time = time.time() - plot_start_time
    print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [17]:
# Training loop
#
# Runs `num_epochs` epochs of training followed by validation. The per-epoch
# training loss, validation loss and learning rate are appended to
# `train_losses`, `val_losses` and `learning_rates`, which must already exist
# (this cell only appends, so re-running it without recreating those lists
# extends the previous history).
#
# At the first and the last epoch the model is additionally evaluated on every
# entry of `all_evaluation_data`, diagnostic plots are written, and the weights
# are saved to 'bb84_nn_model.pth'.
# (`time` is imported in the notebook's first cell.)
total_start_time = time.time()
num_epochs = 5000

# Keyword arguments forwarded unchanged to `safe_objective` for every evaluated
# point; hoisted out of the loops because they never vary.
objective_kwargs = dict(
    alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7,
    epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16,
    e_mis=5e-3, P_ap=0, n_event=1,
)

for epoch in range(num_epochs):
    start_time = time.time()
    is_first_epoch = epoch == 0
    is_last_epoch = epoch == num_epochs - 1

    # Training phase
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the batch size so the epoch loss is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)

    # Validation phase
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_running_loss += loss.item() * inputs.size(0)
    val_loss = val_running_loss / len(val_loader.dataset)
    val_losses.append(val_loss)

    # Update learning rate (the scheduler is stepped with the validation loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

    # Evaluate and plot at the first and last epochs
    eval_time = 0
    plot_time = 0
    if is_first_epoch or is_last_epoch:
        # The last epoch keeps the original relative-error file name; the
        # first-epoch plot gets its own suffix so it is no longer overwritten
        # when the last epoch is plotted.
        rel_error_suffix = '' if is_last_epoch else '_first_epoch'

        model.eval()
        with torch.no_grad():
            # Prepare data for all n_X values
            plot_data = {}
            eval_start = time.time()
            for nx, eval_data in all_evaluation_data.items():
                fiber_lengths = eval_data['fiber_lengths']
                optimized_params_array = eval_data['optimized_params_array']
                optimized_key_rates = eval_data['optimized_key_rates']
                X_test_tensor = eval_data['X_test_tensor']

                # Compute predicted parameters: clamp the scaled network output
                # to [0, 1] before undoing the target scaling.
                # NOTE(review): presumably y_scaler maps targets into [0, 1] - confirm.
                predicted_params_scaled = model(X_test_tensor).cpu().numpy()
                predicted_params_scaled = np.clip(predicted_params_scaled, 0, 1)
                predicted_params_array = y_scaler.inverse_transform(predicted_params_scaled)
                predicted_params_array[:, 0] = np.maximum(predicted_params_array[:, 0], 1e-6) # mu_1
                predicted_params_array[:, 1] = np.maximum(predicted_params_array[:, 1], 1e-6) # mu_2
                predicted_params_array[:, 2] = np.clip(predicted_params_array[:, 2], 0, 1)    # P_mu_1
                predicted_params_array[:, 3] = np.clip(predicted_params_array[:, 3], 0, 1)    # P_mu_2
                predicted_params_array[:, 4] = np.clip(predicted_params_array[:, 4], 0, 1)    # P_X

                # Enforce the sum constraint: where P_mu_1 + P_mu_2 exceeds 1,
                # rescale both so that they sum to exactly 1.
                prob_sum = predicted_params_array[:, 2] + predicted_params_array[:, 3]
                mask_sum_gt_1 = prob_sum > 1.0
                if np.any(mask_sum_gt_1):
                    predicted_params_array[mask_sum_gt_1, 2] /= prob_sum[mask_sum_gt_1] # Scale P_mu_1
                    predicted_params_array[mask_sum_gt_1, 3] /= prob_sum[mask_sum_gt_1] # Scale P_mu_2

                # Compute predicted key rates, remembering which rows produced
                # a usable value (safe_objective returns None otherwise).
                predicted_key_rates = []
                valid_indices = []
                for idx, (params, L) in enumerate(zip(predicted_params_array, fiber_lengths)):
                    key_rate = safe_objective(params, L, nx, **objective_kwargs)
                    if key_rate is None:
                        print(f"Skipping invalid predicted key rate at index {idx} for n_X = {nx}")
                        continue
                    predicted_key_rates.append(key_rate)
                    valid_indices.append(idx)
                predicted_key_rates = np.array(predicted_key_rates)

                # Keep only the rows with a valid predicted key rate so that
                # all arrays stay aligned with predicted_key_rates.
                fiber_lengths = fiber_lengths[valid_indices]
                optimized_key_rates = optimized_key_rates[valid_indices]
                optimized_params_array = optimized_params_array[valid_indices]
                predicted_params_array = predicted_params_array[valid_indices]

                # Store data for plotting
                plot_data[nx] = {
                    'fiber_lengths': fiber_lengths,
                    'optimized_key_rates': optimized_key_rates,
                    'predicted_key_rates': predicted_key_rates,
                    'optimized_params_array': optimized_params_array,
                    'predicted_params_array': predicted_params_array
                }

                # Plot relative errors for this n_X
                plot_relative_errors(
                    fiber_lengths, optimized_key_rates, optimized_params_array,
                    predicted_key_rates, predicted_params_array, epoch,
                    f'parameter_relative_error_nx_{nx:.0e}{rel_error_suffix}.png', nx=nx
                )

            eval_time += time.time() - eval_start

            # Plot subplots at the first and last epochs
            plot_start = time.time()
            if is_first_epoch:
                plot_keyrate_subplots(plot_data, epoch, 'keyrate_subplots_first_epoch.png')
                plot_parameters_subplots(plot_data, epoch, 'parameters_subplots_first_epoch.png')
            if is_last_epoch:
                plot_keyrate_subplots(plot_data, epoch, 'keyrate_subplots_last_epoch.png')
                plot_parameters_subplots(plot_data, epoch, 'parameters_subplots_last_epoch.png')
            plot_time += time.time() - plot_start

            torch.save(model.state_dict(), 'bb84_nn_model.pth')
            print("Model saved to bb84_nn_model.pth")

    # Print epoch results with timing and learning rate
    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, "
          f"Learning Rate: {current_lr:.6f}, Time: {epoch_time:.2f}s (Eval: {eval_time:.2f}s, Plot: {plot_time:.2f}s)")

# Print total training time
total_training_time = time.time() - total_start_time
print(f"--- Total Training Time: {total_training_time:.2f} seconds ---")

Objective returned NaN for params [0.5794454  0.2931385  0.18912245 0.77275324 0.7395108 ], L=0.0, nx=10000
Skipping invalid predicted key rate at index 0 for n_X = 10000
Objective returned NaN for params [0.5793322  0.2930146  0.18885249 0.77275324 0.739054  ], L=0.2002002002002002, nx=10000
Skipping invalid predicted key rate at index 1 for n_X = 10000
Objective returned NaN for params [0.57921886 0.29289067 0.18858255 0.77275324 0.7385973 ], L=0.4004004004004004, nx=10000
Skipping invalid predicted key rate at index 2 for n_X = 10000
Objective returned NaN for params [0.5791056  0.29276675 0.18831271 0.77275324 0.73814064], L=0.6006006006006006, nx=10000
Skipping invalid predicted key rate at index 3 for n_X = 10000
Objective returned NaN for params [0.5789923  0.29264283 0.1880428  0.77275324 0.737684  ], L=0.8008008008008008, nx=10000
Skipping invalid predicted key rate at index 4 for n_X = 10000
Objective returned NaN for params [0.57887906 0.2925189  0.18777284 0.77275324 0.7372

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+04.png
--- Plotting Time: 1.07 seconds ---
Objective returned NaN for params [0.5636161  0.2865196  0.1914458  0.77275324 0.71210897], L=0.0, nx=100000
Skipping invalid predicted key rate at index 0 for n_X = 100000
Objective returned NaN for params [0.56350285 0.28639567 0.19117592 0.77275324 0.7116522 ], L=0.2002002002002002, nx=100000
Skipping invalid predicted key rate at index 1 for n_X = 100000
Objective returned NaN for params [0.56338954 0.28627175 0.19090594 0.77275324 0.71119547], L=0.4004004004004004, nx=100000
Skipping invalid predicted key rate at index 2 for n_X = 100000
Objective returned NaN for params [0.5632763  0.28614783 0.19063605 0.77275324 0.71073884], L=0.6006006006006006, nx=100000
Skipping invalid predicted key rate at index 3 for n_X = 100000
Objective returned NaN for params [0.563163   0.2860239  0.19036613 0.77275324 0.71028215], L=0.8008008008008008, nx=100000
Skipping invalid predicted key rate 

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+05.png
--- Plotting Time: 1.20 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.54884994 0.2782615  0.19541948 0.77275324 0.6868042 ], L=0.0, nx=1000000
Skipping invalid predicted key rate at index 0 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.5484553  0.27782828 0.194509   0.77275324 0.6851319 ], L=0.6006006006006006, nx=1000000
Skipping invalid predicted key rate at index 3 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.5483238  0.27768385 0.19420548 0.77275324 0.6845746 ], L=0.8008008008008008, nx=1000000
Skipping invalid predicted key rate at index 4 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.54819226 0.2775395  0.193902   0.77275324 0.6840171 ], L=1.001001001001001, nx=1000000
Skipping invalid predicted key rate at index 5 for n_X = 1000000
Objective returned non-positive key rate 

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+06.png
--- Plotting Time: 1.12 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.5372783  0.27865764 0.20977019 0.77275324 0.69530016], L=0.2002002002002002, nx=10000000
Skipping invalid predicted key rate at index 1 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.53697413 0.27842483 0.20921704 0.77275324 0.6944593 ], L=0.6006006006006006, nx=10000000
Skipping invalid predicted key rate at index 3 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.536822   0.27830842 0.20894049 0.77275324 0.6940389 ], L=0.8008008008008008, nx=10000000
Skipping invalid predicted key rate at index 4 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.5366699  0.278192   0.20866397 0.77275324 0.6936185 ], L=1.001001001001001, nx=10000000
Skipping invalid predicted key rate at index 5 for n_X = 10000000
Objective returned

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+07.png
--- Plotting Time: 1.09 seconds ---
Objective returned NaN for params [0.5294149  0.28773382 0.2340158  0.7659842  0.73664945], L=0.0, nx=100000000
Skipping invalid predicted key rate at index 0 for n_X = 100000000
Objective returned NaN for params [0.5292946  0.2876232  0.23386113 0.76613885 0.7362847 ], L=0.2002002002002002, nx=100000000
Skipping invalid predicted key rate at index 1 for n_X = 100000000
Objective returned non-positive key rate (-2e+250) for params [0.5291743  0.2875126  0.23370644 0.7662936  0.73591983], L=0.4004004004004004, nx=100000000
Skipping invalid predicted key rate at index 2 for n_X = 100000000
Objective returned NaN for params [0.529054  0.287402  0.2335517 0.7664483 0.7355551], L=0.6006006006006006, nx=100000000
Skipping invalid predicted key rate at index 3 for n_X = 100000000
Objective returned NaN for params [0.52893364 0.28729138 0.2333969  0.7666031  0.7351903 ], L=0.8008008008008008,

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+08.png
--- Plotting Time: 1.23 seconds ---
Objective returned non-positive key rate (-2e+250) for params [0.52930164 0.29918763 0.25186238 0.74813765 0.7895695 ], L=0.0, nx=1000000000
Skipping invalid predicted key rate at index 0 for n_X = 1000000000
Objective returned non-positive key rate (-2e+250) for params [0.52919424 0.29907513 0.2517315  0.7482686  0.7892257 ], L=0.2002002002002002, nx=1000000000
Skipping invalid predicted key rate at index 1 for n_X = 1000000000
Objective returned NaN for params [0.5290869  0.2989626  0.2516005  0.74839944 0.78888184], L=0.4004004004004004, nx=1000000000
Skipping invalid predicted key rate at index 2 for n_X = 1000000000
Objective returned NaN for params [0.52897954 0.2988501  0.2514695  0.74853045 0.78853816], L=0.6006006006006006, nx=1000000000
Skipping invalid predicted key rate at index 3 for n_X = 1000000000
Objective returned NaN for params [0.52887213 0.29873756 0.2513385  0.74

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+09.png
--- Plotting Time: 1.25 seconds ---


  ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linestyle='-', linewidth=2.0, alpha=1.0)  # Thicker line, semi-transparent
  plt.show()


Key rate subplots saved to keyrate_subplots_first_epoch.png


  plt.show()


Parameter subplots saved to parameters_subplots_first_epoch.png
Model saved to bb84_nn_model.pth
Epoch 1/5000, Train Loss: 0.1335, Val Loss: 0.0457, Learning Rate: 0.001000, Time: 11.29s (Eval: 8.14s, Plot: 2.59s)
Epoch 2/5000, Train Loss: 0.0263, Val Loss: 0.0176, Learning Rate: 0.001000, Time: 0.27s (Eval: 0.00s, Plot: 0.00s)
Epoch 3/5000, Train Loss: 0.0126, Val Loss: 0.0101, Learning Rate: 0.001000, Time: 0.28s (Eval: 0.00s, Plot: 0.00s)
Epoch 4/5000, Train Loss: 0.0076, Val Loss: 0.0061, Learning Rate: 0.001000, Time: 0.24s (Eval: 0.00s, Plot: 0.00s)
Epoch 5/5000, Train Loss: 0.0046, Val Loss: 0.0037, Learning Rate: 0.001000, Time: 0.26s (Eval: 0.00s, Plot: 0.00s)
Epoch 6/5000, Train Loss: 0.0028, Val Loss: 0.0025, Learning Rate: 0.001000, Time: 0.26s (Eval: 0.00s, Plot: 0.00s)
Epoch 7/5000, Train Loss: 0.0020, Val Loss: 0.0019, Learning Rate: 0.001000, Time: 0.23s (Eval: 0.00s, Plot: 0.00s)
Epoch 8/5000, Train Loss: 0.0015, Val Loss: 0.0015, Learning Rate: 0.001000, Time: 0.23s (

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+04.png
--- Plotting Time: 1.16 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.5807696  0.21318452 0.09073634 0.6970598  0.5524227 ], L=0.0, nx=100000
Skipping invalid predicted key rate at index 0 for n_X = 100000
Objective returned non-positive key rate (-1e+250) for params [0.58066446 0.21314904 0.09074961 0.69703126 0.55239606], L=0.2002002002002002, nx=100000
Skipping invalid predicted key rate at index 1 for n_X = 100000
Objective returned non-positive key rate (-1e+250) for params [0.5805593  0.21311359 0.09076284 0.69700277 0.55236953], L=0.4004004004004004, nx=100000
Skipping invalid predicted key rate at index 2 for n_X = 100000
Objective returned non-positive key rate (-1e+250) for params [0.58045423 0.21307813 0.09077602 0.69697416 0.55234295], L=0.6006006006006006, nx=100000
Skipping invalid predicted key rate at index 3 for n_X = 100000
Objective returned non-positive key rate (-1e+25

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+05.png
--- Plotting Time: 1.29 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.551146   0.26722535 0.12419744 0.7380333  0.6609638 ], L=0.0, nx=1000000
Skipping invalid predicted key rate at index 0 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.55102104 0.2672019  0.12423263 0.73800564 0.6610273 ], L=0.2002002002002002, nx=1000000
Skipping invalid predicted key rate at index 1 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.5508962  0.26717842 0.1242678  0.737978   0.6610909 ], L=0.4004004004004004, nx=1000000
Skipping invalid predicted key rate at index 2 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.5507713  0.267155   0.12430288 0.7379503  0.6611547 ], L=0.6006006006006006, nx=1000000
Skipping invalid predicted key rate at index 3 for n_X = 1000000
Objective returned non-positive key rate

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+06.png
--- Plotting Time: 1.11 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.52378213 0.31071553 0.15107493 0.76325315 0.75449675], L=0.0, nx=10000000
Skipping invalid predicted key rate at index 0 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.52368957 0.31064945 0.15107803 0.7632125  0.75449294], L=0.2002002002002002, nx=10000000
Skipping invalid predicted key rate at index 1 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.5233195  0.31038514 0.1510901  0.76304996 0.7544777 ], L=1.001001001001001, nx=10000000
Skipping invalid predicted key rate at index 5 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.5231344  0.31025302 0.15109618 0.76296866 0.75447   ], L=1.4014014014014016, nx=10000000
Skipping invalid predicted key rate at index 7 for n_X = 10000000
Objective returned non-positive k

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+07.png
--- Plotting Time: 1.14 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.50507843 0.34072262 0.17028877 0.77275324 0.8223622 ], L=0.0, nx=100000000
Skipping invalid predicted key rate at index 0 for n_X = 100000000
Objective returned non-positive key rate (-1e+250) for params [0.50495577 0.34067386 0.17028947 0.77275324 0.8224087 ], L=0.2002002002002002, nx=100000000
Skipping invalid predicted key rate at index 1 for n_X = 100000000
Objective returned non-positive key rate (-1e+250) for params [0.50483316 0.34062514 0.17029022 0.77275324 0.82245517], L=0.4004004004004004, nx=100000000
Skipping invalid predicted key rate at index 2 for n_X = 100000000
Objective returned non-positive key rate (-1e+250) for params [0.5047105  0.34057632 0.17029099 0.77275324 0.8225016 ], L=0.6006006006006006, nx=100000000
Skipping invalid predicted key rate at index 3 for n_X = 100000000
Objective returned non-p

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+08.png
--- Plotting Time: 1.16 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.48908013 0.3656066  0.18723157 0.77275324 0.8765765 ], L=0.2002002002002002, nx=1000000000
Skipping invalid predicted key rate at index 1 for n_X = 1000000000
Objective returned non-positive key rate (-1e+250) for params [0.48894554 0.3655373  0.18722118 0.77275324 0.8765765 ], L=0.4004004004004004, nx=1000000000
Skipping invalid predicted key rate at index 2 for n_X = 1000000000
Objective returned non-positive key rate (-1e+250) for params [0.48882228 0.3654497  0.18721117 0.77275324 0.8765765 ], L=0.6006006006006006, nx=1000000000
Skipping invalid predicted key rate at index 3 for n_X = 1000000000
Objective returned non-positive key rate (-1e+250) for params [0.48869905 0.36536205 0.18720117 0.77275324 0.8765765 ], L=0.8008008008008008, nx=1000000000
Skipping invalid predicted key rate at index 4 for n_X = 1000000000
O

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to parameter_relative_error_nx_1e+09.png
--- Plotting Time: 1.21 seconds ---


  ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linestyle='-', linewidth=2.0, alpha=1.0)  # Thicker line, semi-transparent
  plt.show()


Key rate subplots saved to keyrate_subplots_last_epoch.png
Parameter subplots saved to parameters_subplots_last_epoch.png
Model saved to bb84_nn_model.pth
Epoch 5000/5000, Train Loss: 0.0002, Val Loss: 0.0002, Learning Rate: 0.000001, Time: 10.77s (Eval: 7.98s, Plot: 2.52s)
--- Total Training Time: 1220.91 seconds ---


  plt.show()


In [18]:
# %%
# Plotting the losses with smoothing
def moving_average(data, window_size=5):
    """Return the simple moving average of `data` over `window_size` samples.

    The series is convolved with a uniform kernel whose weights sum to one.
    Only fully overlapping windows are kept (``mode='valid'``), so for
    ``len(data) >= window_size`` the result has
    ``len(data) - window_size + 1`` entries.
    """
    uniform_kernel = np.ones(window_size)
    uniform_kernel /= window_size
    smoothed = np.convolve(data, uniform_kernel, mode='valid')
    return smoothed

# Number of consecutive epochs averaged into each smoothed point.
window_size = 5

# 1-based epoch numbers for the raw (unsmoothed) loss curves.
epochs = range(1, 1 + len(train_losses))

# Smooth both loss histories with the same window.
smoothed_train_losses, smoothed_val_losses = (
    moving_average(loss_history, window_size)
    for loss_history in (train_losses, val_losses)
)

# The first smoothed value averages epochs 1..window_size, so the smoothed
# curve is labelled starting at epoch `window_size`.
smoothed_epochs = range(window_size, 1 + len(train_losses))

In [19]:
# Training / validation loss per epoch.
# Explicit figure and axes handles are used instead of the implicit pyplot
# "current figure", and each figure is closed once it has been saved so it
# does not stay alive in the kernel.
loss_fig, loss_ax = plt.subplots(figsize=(12, 6))
# Optional smoothed curves (computed in the previous cell):
# loss_ax.plot(smoothed_epochs, smoothed_train_losses, label='Smoothed Training Loss', linestyle='-')
# loss_ax.plot(smoothed_epochs, smoothed_val_losses, label='Smoothed Validation Loss', linestyle='--')
loss_ax.plot(epochs, train_losses, label='Training Loss', linestyle='-', alpha=0.3)
loss_ax.plot(epochs, val_losses, label='Validation Loss', linestyle='--', alpha=0.5)
loss_ax.set_xlabel('Epochs')
# Fixed y-range: anything above 0.002 is cut off so the late-epoch behaviour is visible.
loss_ax.set_ylim(0, 0.002)
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Training and Validation Loss')
loss_ax.grid(True)
loss_fig.tight_layout()
loss_fig.savefig('loss_plot.png', dpi=150)
plt.show()
plt.close(loss_fig)

print("Training Complete")

# Learning rate per epoch.
lr_fig, lr_ax = plt.subplots(figsize=(12, 6))
lr_ax.plot(epochs, learning_rates, label='Learning Rate', linestyle='-', color='tab:blue')
lr_ax.set_xlabel('Epochs')
lr_ax.set_ylabel('Learning Rate')
lr_ax.legend()
lr_ax.set_title('Learning Rate over Epochs')
lr_ax.grid(True)
lr_fig.tight_layout()
lr_fig.savefig('learning_rate_plot.png', dpi=150)
plt.show()
plt.close(lr_fig)

Training Complete


  plt.show()
  plt.show()


In [20]:
# %% [markdown]
# ### Test the Trained Model on n_X = 5e8 Dataset
#
# NOTE: this cell filters the `*_un_seen` arrays in place. Re-run the cell that
# builds them before running this cell a second time.

# Load the saved model; map_location places the stored tensors directly on
# `device`, whatever device they were saved from.
model = BB84NN().to(device)
model.load_state_dict(torch.load('bb84_nn_model.pth', map_location=device))
model.eval()
print("Loaded saved model for testing.")

# Evaluate the model on the n_X = 5e8 dataset
with torch.no_grad():
    predicted_params_scaled_un_seen = model(X_test_tensor_un_seen).cpu().numpy()

# Clamp the scaled output to [0, 1] before undoing the target scaling.
# NOTE(review): presumably y_scaler maps targets into [0, 1] - confirm.
predicted_params_scaled_un_seen = np.clip(predicted_params_scaled_un_seen, 0, 1)
predicted_params_un_seen = y_scaler.inverse_transform(predicted_params_scaled_un_seen)

# Predictions are paired with fiber lengths by position below. zip() would
# silently truncate on a length mismatch (e.g. when this cell is re-run after
# the arrays were already filtered), pairing parameters with the wrong lengths.
if len(fiber_lengths_un_seen) != len(predicted_params_un_seen):
    raise ValueError(
        f"fiber_lengths_un_seen has {len(fiber_lengths_un_seen)} entries but the model produced "
        f"{len(predicted_params_un_seen)} predictions. This cell filters the *_un_seen arrays in "
        "place; re-run the cell that builds them before re-running this one."
    )

print("Predicted parameters (mu_1, mu_2, P_mu_1, P_mu_2, P_X):")
print(predicted_params_un_seen[:5])
predicted_params_un_seen[:, 0] = np.maximum(predicted_params_un_seen[:, 0], 1e-6)  # mu_1
predicted_params_un_seen[:, 1] = np.maximum(predicted_params_un_seen[:, 1], 1e-6)  # mu_2
predicted_params_un_seen[:, 2] = np.clip(predicted_params_un_seen[:, 2], 0, 1)     # P_mu_1
predicted_params_un_seen[:, 3] = np.clip(predicted_params_un_seen[:, 3], 0, 1)     # P_mu_2
predicted_params_un_seen[:, 4] = np.clip(predicted_params_un_seen[:, 4], 0, 1)     # P_X

# Same sum constraint as the evaluation inside the training loop: where
# P_mu_1 + P_mu_2 exceeds 1, rescale both so that they sum to exactly 1.
prob_sum_un_seen = predicted_params_un_seen[:, 2] + predicted_params_un_seen[:, 3]
mask_sum_gt_1_un_seen = prob_sum_un_seen > 1.0
if np.any(mask_sum_gt_1_un_seen):
    predicted_params_un_seen[mask_sum_gt_1_un_seen, 2] /= prob_sum_un_seen[mask_sum_gt_1_un_seen]
    predicted_params_un_seen[mask_sum_gt_1_un_seen, 3] /= prob_sum_un_seen[mask_sum_gt_1_un_seen]

# Evaluate predicted key rates, remembering which rows produced a usable value
# (safe_objective returns None otherwise).
predicted_key_rates_un_seen = []
valid_indices_pred = []
for idx, (params, L) in enumerate(zip(predicted_params_un_seen, fiber_lengths_un_seen)):
    key_rate = safe_objective(params, L, target_nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, 
                              epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
    if key_rate is None:
        print(f"Skipping invalid predicted key rate at index {idx} for n_X = {target_nx}")
        continue
    predicted_key_rates_un_seen.append(key_rate)
    valid_indices_pred.append(idx)
predicted_key_rates_un_seen = np.array(predicted_key_rates_un_seen)

# Further filter based on valid predicted key rates so every array stays
# aligned with predicted_key_rates_un_seen.
fiber_lengths_un_seen = fiber_lengths_un_seen[valid_indices_pred]
optimized_key_rates_un_seen = optimized_key_rates_un_seen[valid_indices_pred]
optimized_params_array_un_seen = optimized_params_array_un_seen[valid_indices_pred]
predicted_params_un_seen = predicted_params_un_seen[valid_indices_pred]
print(f"After filtering invalid key rates in predicted_key_rates_un_seen, {len(predicted_key_rates_un_seen)} data points remain for n_X = 5e8.")

# Debug: Check for NaN in optimized_key_rates_un_seen and predicted_key_rates_un_seen
print("Checking for NaN in optimized_key_rates_un_seen:")
print(f"Number of NaN values: {np.isnan(optimized_key_rates_un_seen).sum()}")
print("Checking for NaN in predicted_key_rates_un_seen:")
print(f"Number of NaN values: {np.isnan(predicted_key_rates_un_seen).sum()}")

Loaded saved model for testing.
Predicted parameters (mu_1, mu_2, P_mu_1, P_mu_2, P_X):
[[0.49388713 0.35793784 0.18170445 0.77275324 0.8624248 ]
 [0.4926497  0.3574458  0.18171202 0.77275324 0.8628936 ]
 [0.49141762 0.35697138 0.1817677  0.77275324 0.86311513]
 [0.49019083 0.35651496 0.18187256 0.77275324 0.86308485]
 [0.48896402 0.35605857 0.18197748 0.77275324 0.8630543 ]]
Objective returned non-positive key rate (-1e+250) for params [0.49388713 0.35793784 0.18170445 0.77275324 0.8624248 ], L=0.0, nx=500000000
Skipping invalid predicted key rate at index 0 for n_X = 500000000
Objective returned non-positive key rate (-1e+250) for params [0.4926497  0.3574458  0.18171202 0.77275324 0.8628936 ], L=2.0202020202020203, nx=500000000
Skipping invalid predicted key rate at index 1 for n_X = 500000000
Objective returned non-positive key rate (-1e+250) for params [0.49019083 0.35651496 0.18187256 0.77275324 0.86308485], L=6.0606060606060606, nx=500000000
Skipping invalid predicted key rate a

In [21]:
def plot_keyrate_and_parameters(fiber_lengths, optimized_key_rates, optimized_params_array,
                                predicted_key_rates, predicted_params_array, epoch, filename,
                                learning_rates=None, nx=None):
    """
    Plot optimized vs. predicted key rates and parameters for one n_X value.

    The figure has two panels: log10 of the key rates on the left and the five
    parameters on the right. Points where either key rate is NaN are dropped,
    and the curves are truncated at the first point whose optimized key rate
    is <= 1e-8. The figure is written to `filename`, shown and then closed.

    Parameters:
    - fiber_lengths: Array of fiber lengths.
    - optimized_key_rates: Array of optimized key rates.
    - optimized_params_array: Array of optimized parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X],
      one row per fiber length.
    - predicted_key_rates: Array of predicted key rates.
    - predicted_params_array: Array of predicted parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X],
      one row per fiber length.
    - epoch: Epoch number or string (e.g., 'final_test'); accepted for interface
      compatibility, not used.
    - filename: Output filename for the plot.
    - learning_rates: List of learning rates over epochs (optional, not used).
    - nx: The n_X value shown in the panel titles. When omitted, the titles fall
      back to the previously hard-coded value 5 x 10^8.

    Returns:
    - None. Side effects: saves the figure and prints the output path.
    """
    # Accept plain sequences as well as arrays; boolean-mask indexing below
    # requires NumPy arrays.
    fiber_lengths = np.asarray(fiber_lengths)
    optimized_key_rates = np.asarray(optimized_key_rates)
    predicted_key_rates = np.asarray(predicted_key_rates)
    optimized_params_array = np.asarray(optimized_params_array)
    predicted_params_array = np.asarray(predicted_params_array)

    # Build the mathtext label for n_X used in both titles.
    if nx is None:
        nx_label = r"5 \times 10^8"
    else:
        mantissa_text, _, exponent_text = f"{float(nx):e}".partition('e')
        if exponent_text:
            mantissa = float(mantissa_text)
            exponent = int(exponent_text)
            if mantissa == 1:
                nx_label = f"10^{{{exponent}}}"
            else:
                nx_label = f"{mantissa:g} \\times 10^{{{exponent}}}"
        else:
            # nan / inf have no exponent part; show the value verbatim.
            nx_label = str(nx)

    # Create a figure with 2 subplots: key rate (left), all parameters (right)
    fig = plt.figure(figsize=(15, 6))  # Adjusted height for horizontal layout

    # Define the grid layout: 1 row, 2 columns
    gs = fig.add_gridspec(1, 2, width_ratios=[1, 1])

    # Key Rate Plot (Left)
    ax_keyrate = fig.add_subplot(gs[0, 0])

    # Filter out NaN values
    valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
    fiber_lengths = fiber_lengths[valid_mask]
    optimized_key_rates = optimized_key_rates[valid_mask]
    predicted_key_rates = predicted_key_rates[valid_mask]
    optimized_params_array = optimized_params_array[valid_mask]
    predicted_params_array = predicted_params_array[valid_mask]

    # Find cutoff where optimized key rate becomes very small; everything from
    # the first such point onwards is dropped.
    threshold = 1e-8
    below_threshold = np.flatnonzero(optimized_key_rates <= threshold)
    cutoff_idx = below_threshold[0] if below_threshold.size > 0 else len(fiber_lengths)

    fiber_lengths = fiber_lengths[:cutoff_idx]
    optimized_key_rates = optimized_key_rates[:cutoff_idx]
    predicted_key_rates = predicted_key_rates[:cutoff_idx]
    optimized_params_array = optimized_params_array[:cutoff_idx]
    predicted_params_array = predicted_params_array[:cutoff_idx]

    # Plot key rates with distinct linestyles and increased thickness for predicted
    ax_keyrate.plot(fiber_lengths, np.log10(predicted_key_rates), 'r--', label='Predicted Key Rate', linewidth=4.0, alpha=0.7)
    ax_keyrate.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linewidth=2.0, alpha=1.0)

    ax_keyrate.set_title(f"Key Rates for $n_X = {nx_label}$")
    ax_keyrate.set_xlabel("Fiber Length (km)")
    # The plotted values are log10 of the key rate, so the label says so.
    ax_keyrate.set_ylabel(r"$\log_{10}$(Secret Key Rate per Pulse)")
    ax_keyrate.grid(True)
    ax_keyrate.legend(loc='upper right')

    # Parameter Plot (Right) - Combine probabilities and intensities
    ax_params = fig.add_subplot(gs[0, 1])

    # Raw strings: '\m' is not a valid escape sequence in a normal string literal.
    param_labels = [r'$\mu_1$', r'$\mu_2$', r'$P_{\mu_1}$', r'$P_{\mu_2}$', r'$P_X$']
    colors = ['red', 'purple', 'orange', 'green', 'blue']
    param_indices = [0, 1, 2, 3, 4]  # Columns for mu_1, mu_2, P_mu_1, P_mu_2, P_X

    # Plot all parameters: dashed/thick for predicted, solid/thin for optimized
    for param_idx, label, color in zip(param_indices, param_labels, colors):
        ax_params.plot(fiber_lengths, predicted_params_array[:, param_idx], label=f'Predicted {label}', color=color, linestyle='--', linewidth=4.0, alpha=0.7)
        ax_params.plot(fiber_lengths, optimized_params_array[:, param_idx], label=f'Optimized {label}', color=color, linestyle='-', linewidth=2.0, alpha=1.0)

    ax_params.set_title(f"Parameters for $n_X = {nx_label}$")
    ax_params.set_xlabel("Fiber Length (km)")
    ax_params.set_ylabel("Parameter Value")
    ax_params.set_ylim(0.0, 1.0)
    ax_params.grid(True)

    # Move legend outside the plot to avoid overlap
    ax_params.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Adjust layout to prevent overlap
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"Comparison plot saved to {filename}")

In [22]:
# Final evaluation: for every n_X in `all_evaluation_data`, predict the
# parameters with the current `model`, compute the corresponding key rates and
# write a relative-error plot.
for nx, eval_data in all_evaluation_data.items():
    fiber_lengths = eval_data['fiber_lengths']
    optimized_params_array = eval_data['optimized_params_array']
    optimized_key_rates = eval_data['optimized_key_rates']
    X_test_tensor = eval_data['X_test_tensor']

    with torch.no_grad():
        predicted_params_scaled = model(X_test_tensor).cpu().numpy()

    # Clamp the scaled output to [0, 1] before undoing the target scaling.
    # NOTE(review): presumably y_scaler maps targets into [0, 1] - confirm.
    predicted_params_scaled = np.clip(predicted_params_scaled, 0, 1)
    predicted_params_array = y_scaler.inverse_transform(predicted_params_scaled)
    predicted_params_array[:, 0] = np.maximum(predicted_params_array[:, 0], 1e-6)  # mu_1
    predicted_params_array[:, 1] = np.maximum(predicted_params_array[:, 1], 1e-6)  # mu_2
    predicted_params_array[:, 2] = np.clip(predicted_params_array[:, 2], 0, 1)     # P_mu_1
    predicted_params_array[:, 3] = np.clip(predicted_params_array[:, 3], 0, 1)     # P_mu_2
    predicted_params_array[:, 4] = np.clip(predicted_params_array[:, 4], 0, 1)     # P_X

    # Same sum constraint as the evaluation inside the training loop: where
    # P_mu_1 + P_mu_2 exceeds 1, rescale both so that they sum to exactly 1.
    prob_sum = predicted_params_array[:, 2] + predicted_params_array[:, 3]
    mask_sum_gt_1 = prob_sum > 1.0
    if np.any(mask_sum_gt_1):
        predicted_params_array[mask_sum_gt_1, 2] /= prob_sum[mask_sum_gt_1]
        predicted_params_array[mask_sum_gt_1, 3] /= prob_sum[mask_sum_gt_1]

    # Compute predicted key rates, remembering which rows produced a usable
    # value (safe_objective returns None otherwise).
    predicted_key_rates = []
    valid_indices = []
    for idx, (params, L) in enumerate(zip(predicted_params_array, fiber_lengths)):
        key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, 
                                  epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, 
                                  e_mis=5e-3, P_ap=0, n_event=1)
        if key_rate is None:
            print(f"Skipping invalid predicted key rate at index {idx} for n_X = {nx}")
            continue
        predicted_key_rates.append(key_rate)
        valid_indices.append(idx)
    predicted_key_rates = np.array(predicted_key_rates)

    # Keep only the rows with a valid predicted key rate so that all arrays
    # stay aligned with predicted_key_rates.
    fiber_lengths = fiber_lengths[valid_indices]
    optimized_key_rates = optimized_key_rates[valid_indices]
    optimized_params_array = optimized_params_array[valid_indices]
    predicted_params_array = predicted_params_array[valid_indices]

    # Plot relative errors for this n_X
    plot_start_time = time.time()
    plot_relative_errors(
        fiber_lengths, optimized_key_rates, optimized_params_array,
        predicted_key_rates, predicted_params_array, epoch='final_test',
        filename=f'relative_error_nx_{nx:.0e}.png', nx=nx
    )
    print(f"--- Final Test Plotting Time (relative_error_nx_{nx:.0e}): {time.time() - plot_start_time:.2f} seconds ---")

Objective returned non-positive key rate (-1e+250) for params [0.60711163 0.13833332 0.04560989 0.61332816 0.41392168], L=0.0, nx=10000
Skipping invalid predicted key rate at index 0 for n_X = 10000
Objective returned non-positive key rate (-1e+250) for params [0.6070091  0.13828968 0.04560989 0.6132862  0.41391465], L=0.2002002002002002, nx=10000
Skipping invalid predicted key rate at index 1 for n_X = 10000
Objective returned non-positive key rate (-1e+250) for params [0.60690296 0.1382514  0.04560989 0.61325735 0.4139326 ], L=0.4004004004004004, nx=10000
Skipping invalid predicted key rate at index 2 for n_X = 10000
Objective returned non-positive key rate (-1e+250) for params [0.6067965  0.13821334 0.04560989 0.613229   0.4139515 ], L=0.6006006006006006, nx=10000
Skipping invalid predicted key rate at index 3 for n_X = 10000
Objective returned non-positive key rate (-1e+250) for params [0.60669017 0.1381753  0.04560989 0.6132006  0.41397035], L=0.8008008008008008, nx=10000
Skipping

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+04.png
--- Plotting Time: 1.10 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+04): 1.10 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.5807696  0.21318452 0.09073634 0.6970598  0.5524227 ], L=0.0, nx=100000
Skipping invalid predicted key rate at index 0 for n_X = 100000
Objective returned non-positive key rate (-1e+250) for params [0.58066446 0.21314904 0.09074961 0.69703126 0.55239606], L=0.2002002002002002, nx=100000
Skipping invalid predicted key rate at index 1 for n_X = 100000
Objective returned non-positive key rate (-1e+250) for params [0.5805593  0.21311359 0.09076284 0.69700277 0.55236953], L=0.4004004004004004, nx=100000
Skipping invalid predicted key rate at index 2 for n_X = 100000
Objective returned non-positive key rate (-1e+250) for params [0.58045423 0.21307813 0.09077602 0.69697416 0.55234295], L=0.6006006006006006, nx=100000
Skipping invalid predicted key rate at index 3 fo

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+05.png
--- Plotting Time: 1.20 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+05): 1.20 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.551146   0.26722535 0.12419744 0.7380333  0.6609638 ], L=0.0, nx=1000000
Skipping invalid predicted key rate at index 0 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.55102104 0.2672019  0.12423263 0.73800564 0.6610273 ], L=0.2002002002002002, nx=1000000
Skipping invalid predicted key rate at index 1 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.5508962  0.26717842 0.1242678  0.737978   0.6610909 ], L=0.4004004004004004, nx=1000000
Skipping invalid predicted key rate at index 2 for n_X = 1000000
Objective returned non-positive key rate (-1e+250) for params [0.5507713  0.267155   0.12430288 0.7379503  0.6611547 ], L=0.6006006006006006, nx=1000000
Skipping invalid predicted key rate at ind

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+06.png
--- Plotting Time: 1.11 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+06): 1.11 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.52378213 0.31071553 0.15107493 0.76325315 0.75449675], L=0.0, nx=10000000
Skipping invalid predicted key rate at index 0 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.52368957 0.31064945 0.15107803 0.7632125  0.75449294], L=0.2002002002002002, nx=10000000
Skipping invalid predicted key rate at index 1 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.5233195  0.31038514 0.1510901  0.76304996 0.7544777 ], L=1.001001001001001, nx=10000000
Skipping invalid predicted key rate at index 5 for n_X = 10000000
Objective returned non-positive key rate (-1e+250) for params [0.5231344  0.31025302 0.15109618 0.76296866 0.75447   ], L=1.4014014014014016, nx=10000000
Skipping invalid predicted key rate 

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+07.png
--- Plotting Time: 1.10 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+07): 1.10 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.50507843 0.34072262 0.17028877 0.77275324 0.8223622 ], L=0.0, nx=100000000
Skipping invalid predicted key rate at index 0 for n_X = 100000000
Objective returned non-positive key rate (-1e+250) for params [0.50495577 0.34067386 0.17028947 0.77275324 0.8224087 ], L=0.2002002002002002, nx=100000000
Skipping invalid predicted key rate at index 1 for n_X = 100000000
Objective returned non-positive key rate (-1e+250) for params [0.50483316 0.34062514 0.17029022 0.77275324 0.82245517], L=0.4004004004004004, nx=100000000
Skipping invalid predicted key rate at index 2 for n_X = 100000000
Objective returned non-positive key rate (-1e+250) for params [0.5047105  0.34057632 0.17029099 0.77275324 0.8225016 ], L=0.6006006006006006, nx=100000000
Skipping invalid predicted k

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+08.png
--- Plotting Time: 1.11 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+08): 1.11 seconds ---
Objective returned non-positive key rate (-1e+250) for params [0.48908013 0.3656066  0.18723157 0.77275324 0.8765765 ], L=0.2002002002002002, nx=1000000000
Skipping invalid predicted key rate at index 1 for n_X = 1000000000
Objective returned non-positive key rate (-1e+250) for params [0.48894554 0.3655373  0.18722118 0.77275324 0.8765765 ], L=0.4004004004004004, nx=1000000000
Skipping invalid predicted key rate at index 2 for n_X = 1000000000
Objective returned non-positive key rate (-1e+250) for params [0.48882228 0.3654497  0.18721117 0.77275324 0.8765765 ], L=0.6006006006006006, nx=1000000000
Skipping invalid predicted key rate at index 3 for n_X = 1000000000
Objective returned non-positive key rate (-1e+250) for params [0.48869905 0.36536205 0.18720117 0.77275324 0.8765765 ], L=0.8008008008008008, nx=1000000000
Skippi

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+09.png
--- Plotting Time: 1.08 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+09): 1.08 seconds ---


In [23]:
# Compute Mean Absolute Error (MAE) between the optimizer's parameters and the
# model's predictions for the unseen n_X case. multioutput='raw_values' yields
# one MAE per parameter column instead of a single averaged score.
mae_params = mean_absolute_error(optimized_params_array_un_seen, predicted_params_un_seen, multioutput='raw_values')
param_labels = ['$mu_1$', '$mu_2$', '$P_{mu_1}$', '$P_{mu_2}$', '$P_X$']
for label, mae in zip(param_labels, mae_params):
    print(f"MAE for {label}: {mae:.6f}")

# Compute MAE for key rates on finite entries only.
# Work on float array views under new names so the original inputs are not
# rebound and boolean-mask indexing also works if a list was passed in.
optimized_key_rates_eval = np.asarray(optimized_key_rates_un_seen, dtype=float)
predicted_key_rates_eval = np.asarray(predicted_key_rates_un_seen, dtype=float)

# np.isfinite rejects NaN *and* +/-inf; an infinite value would otherwise be
# passed on to mean_absolute_error instead of being filtered out here.
valid_mask = np.isfinite(optimized_key_rates_eval) & np.isfinite(predicted_key_rates_eval)
if not np.any(valid_mask):
    print("Warning: All key rates are NaN or infinite. Cannot compute MAE for key rates.")
    mae_key_rates = float('nan')
else:
    mae_key_rates = mean_absolute_error(
        optimized_key_rates_eval[valid_mask],
        predicted_key_rates_eval[valid_mask]
    )
    print(f"Computed MAE for Key Rates using {np.sum(valid_mask)} valid data points.")
# Scientific notation: fixed-point with 6 decimals rounds a key-rate MAE of
# order 1e-6 to "0.000001", which hides the actual value.
print(f"MAE for Key Rates: {mae_key_rates:.6e}")

MAE for $mu_1$: 0.001394
MAE for $mu_2$: 0.001606
MAE for $P_{mu_1}$: 0.006342
MAE for $P_{mu_2}$: 0.007561
MAE for $P_X$: 0.003187
Computed MAE for Key Rates using 26 valid data points.
MAE for Key Rates: 0.000001


In [24]:
def _format_nx_for_title(nx):
    """Return a mathtext fragment such as '5 \\times 10^{8}' or '10^{4}' for nx."""
    value = float(nx)
    if not np.isfinite(value) or value <= 0:
        # Nothing sensible to write in scientific form; show the raw value.
        return f"{nx}"
    # Let string formatting split mantissa/exponent; this avoids the
    # floating-point edge cases of int(log10(nx)).
    mantissa_text, exponent_text = f"{value:e}".split('e')
    mantissa = float(mantissa_text)
    exponent = int(exponent_text)
    if mantissa == 1.0:
        return f"10^{{{exponent}}}"
    return f"{mantissa:g} \\times 10^{{{exponent}}}"


def plot_relative_errors(fiber_lengths, optimized_key_rates, optimized_params_array, 
                        predicted_key_rates, predicted_params_array, epoch, filename, 
                        nx=None, threshold=1e-8):
    """
    Plots relative errors of predicted vs. optimized parameters and key rate for a specific n_X.

    The figure is a 2x3 grid: five parameter panels plus one key-rate panel.
    It is written to `filename` and closed; nothing is returned.

    Parameters:
    - fiber_lengths: Array of fiber lengths (km).
    - optimized_key_rates: Array of optimized key rates.
    - optimized_params_array: Array of optimized parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
    - predicted_key_rates: Array of predicted key rates.
    - predicted_params_array: Array of predicted parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
    - epoch: Epoch number or string (e.g., 'final_test'). Currently unused; kept so
      existing callers do not break.
    - filename: Output filename for the plot.
    - nx: The n_X value shown in the figure title (optional). When None, the
      title omits n_X.
    - threshold: Key rate threshold for physical cutoff (default: 1e-8). Data from
      the first point whose optimized key rate is <= threshold onward is dropped.
    """
    # Start timing
    plot_start_time = time.time()

    # Coerce to float arrays so boolean-mask indexing works for list inputs too
    fiber_lengths = np.asarray(fiber_lengths, dtype=float)
    optimized_key_rates = np.asarray(optimized_key_rates, dtype=float)
    predicted_key_rates = np.asarray(predicted_key_rates, dtype=float)
    optimized_params_array = np.asarray(optimized_params_array, dtype=float)
    predicted_params_array = np.asarray(predicted_params_array, dtype=float)

    # Filter out NaN values
    valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
    fiber_lengths = fiber_lengths[valid_mask]
    optimized_key_rates = optimized_key_rates[valid_mask]
    predicted_key_rates = predicted_key_rates[valid_mask]
    optimized_params_array = optimized_params_array[valid_mask]
    predicted_params_array = predicted_params_array[valid_mask]

    # Find cutoff where optimized key rate becomes very small; keep only the
    # points before the first such index (or everything if there is none).
    below_threshold = np.flatnonzero(optimized_key_rates <= threshold)
    cutoff_idx = below_threshold[0] if below_threshold.size > 0 else len(fiber_lengths)

    fiber_lengths = fiber_lengths[:cutoff_idx]
    optimized_key_rates = optimized_key_rates[:cutoff_idx]
    predicted_key_rates = predicted_key_rates[:cutoff_idx]
    optimized_params_array = optimized_params_array[:cutoff_idx]
    predicted_params_array = predicted_params_array[:cutoff_idx]

    # Compute relative errors
    # Parameters: mu_1, mu_2, P_mu_1, P_mu_2, P_X
    # Raw strings: '\m' is not a valid escape sequence in a normal string literal.
    param_labels = [r'$\mu_1$', r'$\mu_2$', r'$P_{\mu_1}$', r'$P_{\mu_2}$', r'$P_X$']
    relative_errors = []
    for i in range(len(param_labels)):
        optimized = optimized_params_array[:, i]
        predicted = predicted_params_array[:, i]
        # Avoid division by zero with a small floor on the denominator
        denominator = np.maximum(optimized, 1e-10)
        relative_errors.append((predicted - optimized) / denominator)

    # Key rate relative error
    denominator = np.maximum(optimized_key_rates, 1e-10)
    key_rate_rel_error = (predicted_key_rates - optimized_key_rates) / denominator

    # Create figure with 6 subplots (2 rows, 3 columns)
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Plot parameter relative errors; panels fill the grid row by row
    for i, label in enumerate(param_labels):
        row, col = divmod(i, 3)
        ax = fig.add_subplot(gs[row, col])
        ax.plot(fiber_lengths, relative_errors[i], 'b-', label=f'Relative Error {label}')
        ax.set_xlabel('Fiber Length (km)')
        ax.set_ylabel('Relative Error')
        ax.set_title(f'Relative Error for {label}')
        ax.grid(True)
        ax.legend(loc='best')

    # Key rate subplot occupies the remaining (bottom-right) cell
    ax_key = fig.add_subplot(gs[1, 2])
    ax_key.plot(fiber_lengths, key_rate_rel_error, 'r-', label='Relative Error Key Rate')
    ax_key.set_xlabel('Fiber Length (km)')
    ax_key.set_ylabel('Relative Error')
    ax_key.set_title('Relative Error for Key Rate')
    ax_key.grid(True)
    ax_key.legend(loc='best')

    # Overall title: reflect the n_X actually being plotted rather than a fixed value
    if nx is None:
        title = "Relative Errors of Parameters"
    else:
        title = f"Relative Errors of Parameters for $n_X = {_format_nx_for_title(nx)}$"
    fig.suptitle(title, fontsize=16)

    # Save and close; operate on `fig` explicitly instead of pyplot's current figure.
    # rect leaves the top 5% of the canvas free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Relative error plot saved to {filename}")

    # Print plotting time
    plot_time = time.time() - plot_start_time
    print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [25]:
# Both final figures for the unseen n_X = 5e8 case take the same five leading
# positional inputs, so gather them once and unpack them into each call.
unseen_plot_inputs = (
    fiber_lengths_un_seen,
    optimized_key_rates_un_seen,
    optimized_params_array_un_seen,
    predicted_key_rates_un_seen,
    predicted_params_un_seen,
)

# Key rate / parameter comparison figure (timed)
plot_start_time = time.time()
plot_keyrate_and_parameters(
    *unseen_plot_inputs,
    epoch='final_test',
    filename='keyrate_parameters_5e8.png',
    learning_rates=learning_rates,
    nx=target_nx,
)
comparison_elapsed = time.time() - plot_start_time
print(f"--- Final Test Plotting Time (keyrate_and_parameters): {comparison_elapsed:.2f} seconds ---")

# Relative-error figure; plot_relative_errors prints its own timing line
plot_relative_errors(
    *unseen_plot_inputs,
    epoch='final_test',
    filename=f'relative_error_nx_{target_nx:.0e}.png',
    nx=target_nx,
)

  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Comparison plot saved to keyrate_parameters_5e8.png
--- Final Test Plotting Time (keyrate_and_parameters): 0.82 seconds ---
Relative error plot saved to relative_error_nx_5e+08.png
--- Plotting Time: 1.14 seconds ---
