In [1]:
# %% [markdown]
# Import necessary libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Subset
import torch.optim as optim
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_absolute_error
import json
from tqdm import tqdm
import torch.nn.functional as F
import joblib
import os
import sys
import numpy as np
from torch.utils.data import DataLoader, random_split
import time  # For timing

# Set matplotlib to use a non-interactive backend
plt.switch_backend('Agg')

# Get the notebook's directory
notebook_dir = os.getcwd()
# Add parent directory to path
project_root = os.path.dirname(notebook_dir)
sys.path.append(project_root)

# %% [markdown]
# Checking if the MPS (Metal Performance Shaders) backend is available

# %%
# Check if MPS is available
if not torch.backends.mps.is_available():
    raise RuntimeError("MPS device not available. Check if PyTorch and macOS set up correctly.")

# Set the device to MPS
device = torch.device("mps")  # Use GPU on M2 Pro

In [2]:
# %%
# Import the objective function
from QKD_Functions.QKD_Functions import (
    calculate_factorial, calculate_tau_n, calculate_eta_ch, calculate_eta_sys,
    calculate_D_mu_k, calculate_n_X_total, calculate_N, calculate_n_Z_total,
    calculate_e_mu_k, calculate_e_obs, calculate_h, calculate_lambda_EC,
    calculate_sqrt_term, calculate_n_pm, calculate_S_0, calculate_S_1,
    calculate_m_mu_k, calculate_m_pm, calculate_v_1, calculate_gamma,
    calculate_Phi, calculate_LastTwoTerm, calculate_l, calculate_R,
    experimental_parameters, other_parameters, calculate_key_rates_and_metrics,
    penalty, objective,
)

  from tqdm.autonotebook import tqdm


In [3]:
# Define a safe wrapper for the objective function
def safe_objective(params, L, nx, **kwargs):
    try:
        key_rate = objective(params, L, nx, **kwargs)[0]
        if np.isnan(key_rate) or np.isinf(key_rate):
            print(f"objective returned NaN/Inf for params {params}, L={L}, nx={nx}: {key_rate}")
            return 0
        if key_rate < 0:
            print(f"objective returned negative key rate {key_rate} for params {params}, L={L}, nx={nx}")
            return 0
        return key_rate
    except Exception as e:
        print(f"Error in objective for params {params}, L={L}, nx={nx}: {e}")
        return 0 # Replace with 1e-8 on error

# Define a wrapper for the objective function to get raw key rates
# Define a wrapper for the objective function to get raw key rates
# def safe_objective(params, L, nx, **kwargs):
#     try:
#         key_rate = objective(params, L, nx, **kwargs)[0]
#         if np.isnan(key_rate) or np.isinf(key_rate):
#             print(f"objective returned NaN/Inf for params {params}, L={L}, nx={nx}: {key_rate}")
#         elif key_rate < 0:
#             print(f"objective returned negative key rate for params {params}, L={L}, nx={nx}: {key_rate}")
#         return key_rate  # Return the raw key rate
#     except Exception as e:
#         print(f"Error in objective for params {params}, L={L}, nx={nx}: {e}")
#         return float('nan')  # Return NaN on error to propagate the issue
    
# %%
# Load dataset
with open('../Training_Data/n_X/good/cleaned_combined_datasets.json', 'r') as f:
    data_by_nx = json.load(f)

print(f"The overall dataset contains {len(data_by_nx)} entries (number of unique n_X values).")

first_key = list(data_by_nx.keys())[0]
print(f"The number of entries associated with the first key ({first_key}) is: {len(data_by_nx[first_key])}")


# Flatten the data structure and filter
cleaned_data = []
for n_x, entries in data_by_nx.items():
    cleaned_data.extend([
        item for item in entries
    if item["key_rate"] > 0 and item["e_1"] * 100 <= 200
])

# Optional: Verify the cleaned dataset
if not cleaned_data:
    print("No valid data after filtering.")
else:
    print(f"Filtered dataset contains {len(cleaned_data)} entries.")
    print("\nSample entry from the cleaned dataset:")
    print(json.dumps(cleaned_data[0], indent=2))
    print("\nNumber of unique n_X values:", len(data_by_nx))


# # Print the number of entries for each n_X before filtering
# for n_x in data_by_nx.keys():
#     print(f"Number of entries for n_X = {n_x}: {len(data_by_nx[n_x])}")


# # Define a safe wrapper for the objective function
# def safe_objective(params, L, nx, **kwargs):
#     try:
#         # Assuming objective returns (key_rate, other_metrics...) or just key_rate
#         result = objective(params, L, nx, **kwargs)
#         # Extract key_rate, handle tuple or single value returns
#         key_rate = result[0] if isinstance(result, (tuple, list)) and len(result) > 0 else result

#         # Check for NaN or negative rates (non-physical)
#         if not isinstance(key_rate, (int, float, np.number)) or np.isnan(key_rate) or key_rate < 0:
#             # Optional: Print only if you want to see these specific cases during debug
#             # print(f"objective returned invalid rate {key_rate} for params {params}, L={L}, nx={nx}")
#             return 0.0  # Return 0 for invalid rates
#         return float(key_rate) # Ensure it's a standard float

#     except (ValueError, TypeError, ZeroDivisionError, FloatingPointError) as e: # Catch common calculation errors
#         # Optional: Print calculation errors if needed during debug
#         # print(f"Calculation Error in objective for params {params}, L={L}, nx={nx}: {type(e).__name__} - {e}")
#         return 0.0 # Return 0 on calculation error

#     except Exception as e: # Catch any unexpected errors
#         print(f"Unexpected Error in objective for params {params}, L={L}, nx={nx}: {type(e).__name__} - {e}")
#         # You might want to print the traceback here for unexpected errors during debugging
#         # import traceback
#         # traceback.print_exc()
#         return 0.0 # Return 0 on unexpected error

# # Define a worker function suitable for starmap
# # (Ensure this matches the arguments used in your starmap calls)
# def safe_objective_starmap_worker(params, L, nx,
#                                 alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7,
#                                 epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16,
#                                 e_mis=5e-3, P_ap=0, n_event=1):
#     """Calls safe_objective with arguments suitable for starmap."""
#     return safe_objective(params, L, nx,
#                           alpha=alpha, eta_Bob=eta_Bob, P_dc_value=P_dc_value,
#                           epsilon_sec=epsilon_sec, epsilon_cor=epsilon_cor, f_EC=f_EC,
#                           e_mis=e_mis, P_ap=P_ap, n_event=n_event)

The overall dataset contains 6 entries (number of unique n_X values).
The number of entries associated with the first key (10000.0) is: 736
Filtered dataset contains 5310 entries.

Sample entry from the cleaned dataset:
{
  "fiber_length": 0.0,
  "e_1": 0.0,
  "e_2": 6.221848749616356,
  "e_3": 0.5,
  "e_4": 4.0,
  "key_rate": 0.00037995895580899045,
  "optimized_params": {
    "mu_1": 0.6126301275287198,
    "mu_2": 0.13818683663411718,
    "P_mu_1": 0.04604299826386499,
    "P_mu_2": 0.6110117432731096,
    "P_X_value": 0.414731486913827
  }
}

Number of unique n_X values: 6


In [4]:
# def objective(params, L, nx, alpha, eta_Bob, P_dc_value, epsilon_sec, epsilon_cor, f_EC, e_mis, P_ap, n_event):
#     mu_1, mu_2, P_mu_1, P_mu_2, P_X = params
#     # Existing computations
#     eta_ch = calculate_eta_ch(L, alpha)
#     eta_sys = calculate_eta_sys(eta_ch, eta_Bob)
#     n_X_total = calculate_n_X_total(nx, P_X, eta_sys, P_dc_value)
#     e_obs = calculate_e_obs(...)  # Replace with actual computation
#     print(f"Debug for params={params}, L={L}, nx={nx}:")
#     print(f"eta_ch: {eta_ch}, eta_sys: {eta_sys}, n_X_total: {n_X_total}, e_obs: {e_obs}")
#     key_rate = calculate_R(...)  # Replace with actual computation
#     print(f"Raw key rate: {key_rate}")
#     return key_rate, None

In [5]:
# # Flatten the data structure and filter
# all_data = []
# for n_x, entries in data_by_nx.items():
#     for item in entries:
#         item['n_X'] = float(n_x)
#         all_data.append(item)

# # Filter the data
# filtered_data = [item for item in all_data if item["key_rate"] > 0 and item["e_1"] * 100 <= 200]
# print(f"Filtered dataset contains {len(filtered_data)} entries.")

# # Split into train+val (85%) and test (15%)
# train_val_data, test_data = train_test_split(filtered_data, test_size=0.15, random_state=42)
# print(f"Train+Val dataset: {len(train_val_data)} entries, Test dataset: {len(test_data)} entries.")

# # Further split train+val into train (70/85 ≈ 82.35%) and val (15/85 ≈ 17.65%)
# train_data, val_data = train_test_split(train_val_data, test_size=0.1765, random_state=42)
# print(f"Train dataset: {len(train_data)} entries, Val dataset: {len(val_data)} entries.")

# # Prepare training and validation datasets
# X_train = np.array([[item['e_1'], item['e_2'], item['e_3'], item['e_4']] for item in train_data])
# Y_train = np.array([[item['optimized_params']['mu_1'], item['optimized_params']['mu_2'],
#                      item['optimized_params']['P_mu_1'], item['optimized_params']['P_mu_2'],
#                      item['optimized_params']['P_X_value']] for item in train_data])

# X_val = np.array([[item['e_1'], item['e_2'], item['e_3'], item['e_4']] for item in val_data])
# Y_val = np.array([[item['optimized_params']['mu_1'], item['optimized_params']['mu_2'],
#                    item['optimized_params']['P_mu_1'], item['optimized_params']['P_mu_2'],
#                    item['optimized_params']['P_X_value']] for item in val_data])

# # Scale the data
# scaler = StandardScaler()
# X_train = scaler.fit_transform(X_train)
# X_val = scaler.transform(X_val)

# y_scaler = MinMaxScaler()
# Y_train = y_scaler.fit_transform(Y_train)
# Y_val = y_scaler.transform(Y_val)

# # Save the scalers
# joblib.dump(scaler, 'scaler.pkl')
# joblib.dump(y_scaler, 'y_scaler.pkl')

# # Create datasets and data loaders
# train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(Y_train, dtype=torch.float32))
# val_dataset = TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(Y_val, dtype=torch.float32))

# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
# # Flatten the data structure and filter
# cleaned_data = []
# for n_x, entries in data_by_nx.items():
#     filtered_entries = [item for item in entries if item["key_rate"] > 0 and item["e_1"] * 100 <= 200]
#     print(f"After filtering, n_X = {n_x} has {len(filtered_entries)} entries.")
#     cleaned_data.extend(filtered_entries)

# # Verify the cleaned dataset
# if not cleaned_data:
#     print("No valid data after filtering.")
# else:
#     print(f"Filtered dataset contains {len(cleaned_data)} entries.")
#     print("\nSample entry from the cleaned dataset:")
#     print(json.dumps(cleaned_data[0], indent=2))
#     print("\nNumber of unique n_X values:", len(data_by_nx))

X = np.array([[item['e_1'], item['e_2'], item['e_3'], item['e_4']] for item in cleaned_data])
Y = np.array([[item['optimized_params']['mu_1'], item['optimized_params']['mu_2'], item['optimized_params']['P_mu_1'], item['optimized_params']['P_mu_2'], item['optimized_params']['P_X_value']] for item in cleaned_data])

# # %%
# X = np.array([[item['e_1'], item['e_2'], item['e_3'], item['e_4']] for item in cleaned_data])
# Y = np.array([[item['optimized_params']['mu_1'], item['optimized_params']['mu_2'], item['optimized_params']['P_mu_1'], item['optimized_params']['P_mu_2'], item['optimized_params']['P_X_value']] for item in cleaned_data])

In [7]:
from sklearn.utils import shuffle

X, Y = shuffle(X, Y, random_state=42)

scaler = StandardScaler()
X = scaler.fit_transform(X)  # Fit and transform on training data

y_scaler = MinMaxScaler()  # Scale targets to [0, 1]
Y = y_scaler.fit_transform(Y)

# Save the scalers
joblib.dump(scaler, 'scaler.pkl')  # Save StandardScaler
joblib.dump(y_scaler, 'y_scaler.pkl')  # Save MinMaxScaler

print(f"X shape: {X.shape}, Y shape: {Y.shape}")
dataset = TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.float32))

# %%
# Split dataset into train and validation sets
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Data loaders with increased batch size
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

X shape: (5310, 4), Y shape: (5310, 5)


In [8]:
# %%
# Load evaluation data for n_X = 5e8
with open("../n_X/good/5e8_100_reordered_qkd_grouped_dataset_20250226_191349.json", 'r') as f:
    evaluation_dataset = json.load(f)

# Select target n_X for evaluation
target_nx = 5 * 100000000  # 5e8
nx_key = str(float(target_nx))
if nx_key not in evaluation_dataset:
    raise ValueError(f"No data found for n_X = {target_nx}")
evaluation_data = evaluation_dataset[nx_key]

# Verify the size of the evaluation dataset
print(f"Number of data points for n_X = {target_nx}: {len(evaluation_data)}")

Number of data points for n_X = 500000000: 100


In [9]:
# Extract evaluation fiber lengths and optimized parameters
fiber_lengths_un_seen = np.array([entry["fiber_length"] for entry in evaluation_data])
optimized_params_array_un_seen = np.array([list(entry["optimized_params"].values()) for entry in evaluation_data])

# Compute optimized key rates for evaluation
optimized_key_rates_un_seen = []
valid_indices_un_seen = []
for idx, (params, L) in enumerate(zip(optimized_params_array_un_seen, fiber_lengths_un_seen)):
    key_rate = safe_objective(params, L, target_nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
    if np.isnan(key_rate):
        print(f"NaN detected in optimized_key_rates_un_seen at index {idx} with params {params}, fiber_length={L}")
    else:
        optimized_key_rates_un_seen.append(key_rate)
        valid_indices_un_seen.append(idx)
optimized_key_rates_un_seen = np.array(optimized_key_rates_un_seen)

# Filter the evaluation data to exclude NaN key rates
fiber_lengths_un_seen = fiber_lengths_un_seen[valid_indices_un_seen]
optimized_params_array_un_seen = optimized_params_array_un_seen[valid_indices_un_seen]

objective returned negative key rate -2.5459075744215506e-31 for params [5.21125578e-01 3.37384994e-01 1.92634353e-01 6.63224253e-01
 1.00000000e-12], L=189.89898989898992, nx=500000000
objective returned negative key rate -1.0847670075212521e-30 for params [9.00000000e-01 4.09777073e-01 1.21908112e-01 6.52515070e-01
 1.00000000e-12], L=191.91919191919195, nx=500000000
objective returned negative key rate -1.0756742243147874e-30 for params [8.13544868e-01 2.57789157e-01 3.75126595e-01 1.07065241e-01
 1.00000000e-12], L=193.93939393939394, nx=500000000
objective returned negative key rate -5.9934352286340745e-31 for params [5.53278979e-01 2.55719551e-01 4.01755354e-01 4.61925236e-01
 1.00000000e-12], L=195.95959595959596, nx=500000000
objective returned negative key rate -8.89190788063512e-31 for params [7.59296721e-01 3.16646804e-02 3.70889007e-01 2.18637344e-01
 1.00000000e-12], L=197.97979797979798, nx=500000000
objective returned negative key rate -1.1490135538073274e-30 for params 

In [10]:
# Prepare test inputs for evaluation (n_X = 5e8)
X_test_un_seen = []
for L in fiber_lengths_un_seen:
    e_1 = L / 100
    e_2 = -np.log10(6e-7)
    e_3 = 5e-3 * 100
    e_4 = np.log10(target_nx)
    X_test_un_seen.append([e_1, e_2, e_3, e_4])

X_test_un_seen = np.array(X_test_un_seen)
X_test_scaled_un_seen = scaler.transform(X_test_un_seen)
X_test_tensor_un_seen = torch.tensor(X_test_scaled_un_seen, dtype=torch.float32).to(device)
print(f"After filtering NaN in optimized_key_rates_un_seen, {len(optimized_key_rates_un_seen)} data points remain for n_X = 5e8.")

After filtering NaN in optimized_key_rates_un_seen, 100 data points remain for n_X = 5e8.


In [11]:
# %% [markdown]
# ### Prepare Evaluation Data for All n_X Values (10^4 to 10^9) Before Training

# Load the combined dataset
combined_file_path = '../Training_Data/n_X/good/cleaned_combined_datasets.json'
try:
    with open(combined_file_path, 'r') as f:
        combined_data = json.load(f)
    print(f"Loaded combined dataset with {len(combined_data)} n_X entries.")
except FileNotFoundError:
    raise FileNotFoundError(f"Combined dataset not found at {combined_file_path}.")

Loaded combined dataset with 6 n_X entries.


In [12]:
# Define the range of n_X values to evaluate
nx_values = [10**s for s in range(4, 10)]  # 10^4 to 10^9

# Dictionary to store evaluation data for each n_X
all_evaluation_data = {}

In [13]:
# # Prepare evaluation data from the test set
# nx_values = [10**s for s in range(4, 10)]
# all_evaluation_data = {}

# for nx in nx_values:
#     # Filter test data for this n_X
#     eval_data = [item for item in test_data if item['n_X'] == nx]
#     if not eval_data:
#         print(f"No test data found for n_X = {nx}. Skipping...")
#         continue
#     print(f"Extracted {len(eval_data)} test entries for n_X = {nx}.")

#     # Extract fiber lengths and optimized parameters
#     fiber_lengths = np.array([entry["fiber_length"] for entry in eval_data])
#     optimized_params_array = np.array([list(entry["optimized_params"].values()) for entry in eval_data])

#     # Compute optimized key rates
#     optimized_key_rates = []
#     valid_indices = []
#     for idx, (params, L) in enumerate(zip(optimized_params_array, fiber_lengths)):
#         key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7,
#                                   epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16,
#                                   e_mis=5e-3, P_ap=0, n_event=1)
#         if np.isnan(key_rate):
#             print(f"NaN detected in optimized_key_rates for n_X = {nx} at index {idx} with params {params}, fiber_length={L}")
#         else:
#             optimized_key_rates.append(key_rate)
#             valid_indices.append(idx)
#     optimized_key_rates = np.array(optimized_key_rates)

#     # Filter the evaluation data to exclude NaN key rates
#     fiber_lengths = fiber_lengths[valid_indices]
#     optimized_params_array = optimized_params_array[valid_indices]

#     # Prepare test inputs for evaluation
#     X_test = []
#     for L in fiber_lengths:
#         e_1 = L / 100
#         e_2 = -np.log10(6e-7)
#         e_3 = 5e-3 * 100
#         e_4 = np.log10(nx)
#         X_test.append([e_1, e_2, e_3, e_4])
#     X_test = np.array(X_test)
#     X_test_scaled = scaler.transform(X_test)
#     X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32).to(device)

#     # Store evaluation data
#     all_evaluation_data[nx] = {
#         'fiber_lengths': fiber_lengths,
#         'optimized_params_array': optimized_params_array,
#         'optimized_key_rates': optimized_key_rates,
#         'X_test_tensor': X_test_tensor
#     }
#     print(f"After filtering NaN, {len(optimized_key_rates)} data points remain for n_X = {nx}.")

In [14]:
# Extract and prepare evaluation data for each n_X
for nx in nx_values:
    nx_key = str(float(nx))
    if nx_key not in combined_data:
        print(f"No data found for n_X = {nx} in the combined dataset. Skipping...")
        continue
    evaluation_data = combined_data[nx_key]
    print(f"Extracted {len(evaluation_data)} entries for n_X = {nx}.")

    # Extract fiber lengths and optimized parameters
    fiber_lengths = np.array([entry["fiber_length"] for entry in evaluation_data])
    optimized_params_array = np.array([list(entry["optimized_params"].values()) for entry in evaluation_data])

    # Compute optimized key rates
    optimized_key_rates = []
    valid_indices = []
    for idx, (params, L) in enumerate(zip(optimized_params_array, fiber_lengths)):
        key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
        if np.isnan(key_rate):
            print(f"NaN detected in optimized_key_rates for n_X = {nx} at index {idx} with params {params}, fiber_length={L}")
        else:
            optimized_key_rates.append(key_rate)
            valid_indices.append(idx)
    optimized_key_rates = np.array(optimized_key_rates)

    # Filter the evaluation data to exclude NaN key rates
    fiber_lengths = fiber_lengths[valid_indices]
    optimized_params_array = optimized_params_array[valid_indices]

    # Prepare test inputs for evaluation
    X_test = []
    for L in fiber_lengths:
        e_1 = L / 100
        e_2 = -np.log10(6e-7)
        e_3 = 5e-3 * 100
        e_4 = np.log10(nx)
        X_test.append([e_1, e_2, e_3, e_4])
    X_test = np.array(X_test)
    X_test_scaled = scaler.transform(X_test)
    X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32).to(device)

    # Store all evaluation data in a dictionary
    all_evaluation_data[nx] = {
        'fiber_lengths': fiber_lengths,
        'optimized_params_array': optimized_params_array,
        'optimized_key_rates': optimized_key_rates,
        'X_test_tensor': X_test_tensor
    }
    print(f"After filtering NaN, {len(optimized_key_rates)} data points remain for n_X = {nx}.")

Extracted 736 entries for n_X = 10000.
After filtering NaN, 736 data points remain for n_X = 10000.
Extracted 855 entries for n_X = 100000.
After filtering NaN, 855 data points remain for n_X = 100000.
Extracted 902 entries for n_X = 1000000.
After filtering NaN, 902 data points remain for n_X = 1000000.
Extracted 927 entries for n_X = 10000000.
After filtering NaN, 927 data points remain for n_X = 10000000.
Extracted 942 entries for n_X = 100000000.
After filtering NaN, 942 data points remain for n_X = 100000000.
Extracted 948 entries for n_X = 1000000000.
After filtering NaN, 948 data points remain for n_X = 1000000000.


In [15]:
# %%
# Define the neural network model
class BB84NN(nn.Module):
    def __init__(self):
        super(BB84NN, self).__init__()
        self.fc1 = nn.Linear(4, 16)
        self.fc2 = nn.Linear(16, 32)
        self.fc3 = nn.Linear(32, 16)
        self.fc4 = nn.Linear(16, 5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [16]:
# Initialize model, loss, optimizer, and scheduler
model = BB84NN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-6)

# Lists to store metrics
train_losses = []
val_losses = []
learning_rates = []

In [17]:
def plot_keyrate_subplots(all_data, epoch, filename):
    """
    Plots key rates vs fiber length for multiple n_X values on separate subplots. all_data: dict with keys as n_X values and values containing fiber_lengths, optimized_key_rates, predicted_key_rates.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    fig.suptitle(f"Key Rates vs Fiber Length at Epoch {epoch+1} ($n_X = 10^s, s=4,5,...,9$)", fontsize=16)

    axes = axes.flatten()
    nx_values = sorted(all_data.keys())  # Ensure n_X values are sorted

    for i, nx in enumerate(nx_values):
        if i >= 6:  # Limit to 6 subplots
            break

        ax = axes[i]
        fiber_lengths = all_data[nx]['fiber_lengths']
        optimized_key_rates = all_data[nx]['optimized_key_rates']
        predicted_key_rates = all_data[nx].get('predicted_key_rates', None)

        # Filter out NaN values
        valid_mask = ~(np.isnan(optimized_key_rates))
        if predicted_key_rates is not None:
            valid_mask &= ~np.isnan(predicted_key_rates)
        fiber_lengths = fiber_lengths[valid_mask]
        optimized_key_rates = optimized_key_rates[valid_mask]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[valid_mask]

        # Find cutoff where optimized key rate becomes very small
        threshold = 1e-8
        cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
        if len(cutoff_idx) > 0:
            cutoff_idx = cutoff_idx[0]
        else:
            cutoff_idx = len(fiber_lengths)

        fiber_lengths = fiber_lengths[:cutoff_idx]
        optimized_key_rates = optimized_key_rates[:cutoff_idx]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[:cutoff_idx]

        # Plot key rates with adjusted linestyles and transparency
        ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linestyle='-', linewidth=2.0, alpha=1.0)  # Thicker line, semi-transparent
        if predicted_key_rates is not None and len(predicted_key_rates) > 0:
            ax.plot(fiber_lengths, np.log10(predicted_key_rates), 'r', label='Predicted Key Rate', 
                    linestyle=':',
                    # linestyle=(0, (1, 10)),  # Short dash (1 point) to look like a dot, 10-point gap
                    linewidth=6.0, alpha=1.0) # Thicker line, fully opaque
        
                
        # Compute the exponent for nx (e.g., nx = 10000 -> exponent = 4)
        exponent = int(np.log10(nx))
        # Set the title with LaTeX formatting for 10^exponent
        ax.set_title(f"$n_X = 10^{{{exponent}}}$")
        ax.grid(True)

        if i >= 3:
            ax.set_xlabel("Fiber Length (km)")
        if i % 3 == 0:
            ax.set_ylabel("Secret Key Rate per Pulse")

        ax.legend(loc='upper right')

    # Remove empty subplots
    for i in range(len(nx_values), 6):
        fig.delaxes(axes[i])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)

    print(f"Key rate subplots saved to {filename}")

In [18]:
def plot_parameters_subplots(all_data, epoch, filename):
    """
    Plots optimized and predicted parameters vs fiber length for multiple n_X values on separate subplots.
    all_data: dict with keys as n_X values and values containing fiber_lengths, optimized_params_array, predicted_params_array.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    fig.suptitle(f"Optimized and Predicted Parameters vs Fiber Length at Epoch {epoch+1} ($n_X = 10^s, s=4,5,...,9$)", fontsize=16)

    axes = axes.flatten()
    nx_values = sorted(all_data.keys())

    param_labels = ['$P_X$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$\mu_1$', '$\mu_2$']
    colors = ['blue', 'orange', 'green', 'red', 'purple']

    # Store handles and labels for the shared legend
    handles, labels = [], []

    for i, nx in enumerate(nx_values):
        if i >= 6:
            break

        ax = axes[i]
        fiber_lengths = all_data[nx]['fiber_lengths']
        optimized_params_array = all_data[nx]['optimized_params_array']
        predicted_params_array = all_data[nx].get('predicted_params_array', None)

        # Filter out NaN values in key rates to align with key rate plots
        optimized_key_rates = all_data[nx]['optimized_key_rates']
        valid_mask = ~(np.isnan(optimized_key_rates))
        fiber_lengths = fiber_lengths[valid_mask]
        optimized_params_array = optimized_params_array[valid_mask]
        if predicted_params_array is not None:
            predicted_params_array = predicted_params_array[valid_mask]

        # Truncate based on key rate cutoff
        threshold = 1e-8
        cutoff_idx = np.where(optimized_key_rates[valid_mask] <= threshold)[0]
        if len(cutoff_idx) > 0:
            cutoff_idx = cutoff_idx[0]
        else:
            cutoff_idx = len(fiber_lengths)

        fiber_lengths = fiber_lengths[:cutoff_idx]
        optimized_params_array = optimized_params_array[:cutoff_idx]
        if predicted_params_array is not None:
            predicted_params_array = predicted_params_array[:cutoff_idx]

        # Plot parameters with slight transparency for predicted lines
        for param_idx, (label, color) in enumerate(zip(param_labels, colors)):
            # Plot optimized line
            line_opt = ax.plot(fiber_lengths, optimized_params_array[:, param_idx], 
                               label=f'Optimized {label}', color=color, linestyle='-', linewidth=2.0, alpha=1.0)
            # Plot predicted line
            if predicted_params_array is not None:
                line_pred = ax.plot(fiber_lengths, predicted_params_array[:, param_idx], 
                                    label=f'Predicted {label}', color=color, linestyle=':', 
                                    linewidth=4.0, alpha=0.7)
            
            # Collect handles and labels only from the first subplot (i=0)
            if i == 0:
                handles.append(line_opt[0])
                labels.append(f'Optimized {label}')
                if predicted_params_array is not None:
                    handles.append(line_pred[0])
                    labels.append(f'Predicted {label}')

        # Compute the exponent for nx (e.g., nx = 10000 -> exponent = 4)
        exponent = int(np.log10(nx))
        # Set the title with LaTeX formatting for 10^exponent
        ax.set_title(f"$n_X = 10^{{{exponent}}}$")
        
        ax.set_ylim(0.0, 1.0)
        ax.grid(True)

        if i >= 3:
            ax.set_xlabel("Fiber Length (km)")
        if i % 3 == 0:
            ax.set_ylabel("Parameter Value")

    # Remove empty subplots
    for i in range(len(nx_values), 6):
        fig.delaxes(axes[i])

    # Add a single shared legend outside the entire grid
    fig.legend(handles, labels, loc='center left', bbox_to_anchor=(0.92, 0.5))

    # Adjust layout to make space for the legend on the right
    plt.tight_layout(rect=[0, 0, 0.9, 0.95])  # Leave 10% space on the right for the shared legend
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    print(f"Parameter subplots saved to {filename}")

In [19]:
# def plot_absolute_differences_keyrate_and_parameters(fiber_lengths, optimized_key_rates, optimized_params_array,
#                                                     predicted_key_rates, predicted_params_array, epoch, filename,
#                                                     nx=None, threshold=1e-8):
#     """
#     Plots absolute differences of key rates and parameters for a specific n_X value.

#     Parameters:
#     - fiber_lengths: Array of fiber lengths (km).
#     - optimized_key_rates: Array of optimized key rates.
#     - optimized_params_array: Array of optimized parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
#     - predicted_key_rates: Array of predicted key rates.
#     - predicted_params_array: Array of predicted parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
#     - epoch: Epoch number or string (e.g., 'final_test').
#     - filename: Output filename for the plot.
#     - nx: The n_X value for the plot title (optional).
#     - threshold: Key rate threshold for physical cutoff (default: 1e-8).
#     """
#     # Start timing
#     plot_start_time = time.time()

#     # Filter out NaN values
#     valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
#     fiber_lengths = fiber_lengths[valid_mask]
#     optimized_key_rates = optimized_key_rates[valid_mask]
#     predicted_key_rates = predicted_key_rates[valid_mask]
#     optimized_params_array = optimized_params_array[valid_mask]
#     predicted_params_array = predicted_params_array[valid_mask]

#     # Find cutoff where optimized key rate becomes very small
#     cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
#     if len(cutoff_idx) > 0:
#         cutoff_idx = cutoff_idx[0]
#     else:
#         cutoff_idx = len(fiber_lengths)

#     fiber_lengths = fiber_lengths[:cutoff_idx]
#     optimized_key_rates = optimized_key_rates[:cutoff_idx]
#     predicted_key_rates = predicted_key_rates[:cutoff_idx]
#     optimized_params_array = optimized_params_array[:cutoff_idx]
#     predicted_params_array = predicted_params_array[:cutoff_idx]

#     # Compute absolute differences
#     key_rate_abs_diff = predicted_key_rates - optimized_key_rates
#     param_abs_diffs = predicted_params_array - optimized_params_array
#     param_labels = ['$\mu_1$', '$\mu_2$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$P_X$']

#     # Create figure with 6 subplots (2 rows, 3 columns)
#     fig = plt.figure(figsize=(18, 10))
#     gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    

#     # Plot absolute differences for parameters
#     for i in range(5):
#         row = (i + 1) // 3
#         col = (i + 1) % 3
#         ax = fig.add_subplot(gs[row, col])
#         ax.plot(fiber_lengths, param_abs_diffs[:, i], 'b-', label=f'Abs. Diff. {param_labels[i]}', linewidth=2.0)
#         ax.set_xlabel('Fiber Length (km)')
#         ax.set_ylabel('Absolute Difference')
#         ax.set_title(f'Abs. Diff. for {param_labels[i]}')
#         ax.grid(True)
#         ax.legend(loc='best')

#     # Plot absolute difference for key rate
#     ax_key = fig.add_subplot(gs[0, 0])
#     ax_key.plot(fiber_lengths, key_rate_abs_diff, 'g-', label='Predicted - Optimized Key Rate', linewidth=2.0)
#     ax_key.set_xlabel('Fiber Length (km)')
#     ax_key.set_ylabel('Absolute Difference')
#     ax_key.set_title('Absolute Difference in Key Rate')
#     ax_key.grid(True)
#     ax_key.legend(loc='best')
    
#     # Overall title
#     exponent = int(np.log10(nx)) if nx is not None else ''
#     fig.suptitle(f'Absolute Differences for Key Rate and Parameters ($n_X = 10^{{{exponent}}}$)', fontsize=16)

#     # Save and close
#     plt.tight_layout(rect=[0, 0, 1, 0.95])
#     plt.savefig(filename, dpi=300, bbox_inches='tight')
#     plt.show()
#     plt.close(fig)
#     print(f"Absolute differences plot saved to {filename}")

#     # Print plotting time
#     plot_time = time.time() - plot_start_time
#     print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [20]:
import matplotlib.pyplot as plt
import numpy as np
import time

def plot_absolute_differences_keyrate_and_parameters(fiber_lengths, optimized_key_rates, optimized_params_array,
                                                    predicted_key_rates, predicted_params_array, epoch, filename,
                                                    nx=None, threshold=1e-8):
    """
    Plots absolute differences of key rates and parameters for a specific n_X value.

    Parameters:
    - fiber_lengths: Array of fiber lengths (km).
    - optimized_key_rates: Array of optimized key rates.
    - optimized_params_array: Array of optimized parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
    - predicted_key_rates: Array of predicted key rates.
    - predicted_params_array: Array of predicted parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
    - epoch: Epoch number or string (e.g., 'final_test').
    - filename: Output filename for the plot.
    - nx: The n_X value for the plot title (optional).
    - threshold: Key rate threshold for physical cutoff (default: 1e-8).
    """
    # Start timing
    plot_start_time = time.time()

    # Filter out NaN values
    valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
    fiber_lengths = fiber_lengths[valid_mask]
    optimized_key_rates = optimized_key_rates[valid_mask]
    predicted_key_rates = predicted_key_rates[valid_mask]
    optimized_params_array = optimized_params_array[valid_mask]
    predicted_params_array = predicted_params_array[valid_mask]

    # Find cutoff where optimized key rate becomes very small
    cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
    if len(cutoff_idx) > 0:
        cutoff_idx = cutoff_idx[0]
    else:
        cutoff_idx = len(fiber_lengths)

    fiber_lengths = fiber_lengths[:cutoff_idx]
    optimized_key_rates = optimized_key_rates[:cutoff_idx]
    predicted_key_rates = predicted_key_rates[:cutoff_idx]
    optimized_params_array = optimized_params_array[:cutoff_idx]
    predicted_params_array = predicted_params_array[:cutoff_idx]

    # Compute absolute differences
    key_rate_abs_diff = predicted_key_rates - optimized_key_rates
    param_abs_diffs = predicted_params_array - optimized_params_array
    param_labels = ['$\mu_1$', '$\mu_2$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$P_X$']

    # Create figure with 6 subplots (2 rows, 3 columns)
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Define subplot positions for parameters to avoid bottom-right
    param_positions = [
        (0, 0),  # Top-left
        (0, 1),  # Top-middle
        (0, 2),  # Top-right
        (1, 0),  # Bottom-left
        (1, 1),  # Bottom-middle
    ]

    # Plot absolute differences for parameters
    for i in range(5):
        row, col = param_positions[i]
        ax = fig.add_subplot(gs[row, col])
        ax.plot(fiber_lengths, param_abs_diffs[:, i], 'b-', label=f'Abs. Diff. {param_labels[i]}', linewidth=2.0)
        ax.set_xlabel('Fiber Length (km)')
        ax.set_ylabel('Absolute Difference')
        ax.set_title(f'Abs. Diff. for {param_labels[i]}')
        ax.grid(True)
        ax.legend(loc='best')

    # Plot absolute difference for key rate in bottom-right (gs[1, 2])
    ax_key = fig.add_subplot(gs[1, 2])
    ax_key.plot(fiber_lengths, key_rate_abs_diff, 'g-', label='Predicted - Optimized Key Rate', linewidth=2.0)
    ax_key.set_xlabel('Fiber Length (km)')
    ax_key.set_ylabel('Absolute Difference')
    ax_key.set_title('Absolute Difference in Key Rate')
    ax_key.grid(True)
    ax_key.legend(loc='best')
    
    # Overall title
    exponent = int(np.log10(nx)) if nx is not None else ''
    fig.suptitle(f'Absolute Differences for Key Rate and Parameters ($n_X = 10^{{{exponent}}}$)', fontsize=16)

    # Save and close
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"Absolute differences plot saved to {filename}")

    # Print plotting time
    plot_time = time.time() - plot_start_time
    print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [21]:
def plot_relative_errors(fiber_lengths, optimized_key_rates, optimized_params_array, 
                        predicted_key_rates, predicted_params_array, epoch, filename, 
                        nx=None, threshold=1e-8):
    """
    Plots relative errors of predicted vs. optimized parameters and key rate for a specific n_X.
    """
    plot_start_time = time.time()

    # Filter out NaN values
    valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
    fiber_lengths = fiber_lengths[valid_mask]
    optimized_key_rates = optimized_key_rates[valid_mask]
    predicted_key_rates = predicted_key_rates[valid_mask]
    optimized_params_array = optimized_params_array[valid_mask]
    predicted_params_array = predicted_params_array[valid_mask]

    # Find cutoff where optimized key rate becomes very small
    cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
    if len(cutoff_idx) > 0:
        cutoff_idx = cutoff_idx[0]
    else:
        cutoff_idx = len(fiber_lengths)

    fiber_lengths = fiber_lengths[:cutoff_idx]
    optimized_key_rates = optimized_key_rates[:cutoff_idx]
    predicted_key_rates = predicted_key_rates[:cutoff_idx]
    optimized_params_array = optimized_params_array[:cutoff_idx]
    predicted_params_array = predicted_params_array[:cutoff_idx]

    # Check for negative optimized key rates (shouldn't happen based on data filtering)
    if np.any(optimized_key_rates < 0):
        print(f"Warning: {np.sum(optimized_key_rates < 0)} optimized key rates are negative for n_X={nx}. Min value: {np.min(optimized_key_rates)}")
    
    # Check for negative predicted key rates (indicates insecure protocol)
    negative_pred_mask = predicted_key_rates < 0
    if np.any(negative_pred_mask):
        print(f"Found {np.sum(negative_pred_mask)} negative predicted key rates for n_X={nx}. Min value: {np.min(predicted_key_rates)}")
        # Optionally print fiber lengths where this occurs
        negative_fiber_lengths = fiber_lengths[negative_pred_mask]
        print(f"Fiber lengths with negative predicted key rates: {negative_fiber_lengths[:5]} (first 5 shown)")

    # Compute relative errors for parameters
    relative_errors = []
    param_labels = ['$\mu_1$', '$\mu_2$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$P_X$']
    for i in range(5):
        optimized = optimized_params_array[:, i]
        predicted = predicted_params_array[:, i]
        denominator = np.maximum(optimized, 1e-10)
        rel_error = (predicted - optimized) / denominator
        relative_errors.append(rel_error)
    
    # Key rate relative error
    denominator = np.maximum(optimized_key_rates, threshold)
    key_rate_rel_error = (predicted_key_rates - optimized_key_rates) / denominator

    # Create figure with 6 subplots (2 rows, 3 columns)
    fig = plt.figure(figsize=(18, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Plot relative errors for parameters
    for i in range(5):
        row = i // 3
        col = i % 3
        ax = fig.add_subplot(gs[row, col])
        ax.plot(fiber_lengths, relative_errors[i], 'b-', label=f'Relative Error {param_labels[i]}')
        ax.set_xlabel('Fiber Length (km)')
        ax.set_ylabel('Relative Error')
        ax.set_title(f'Relative Error for {param_labels[i]}')
        ax.grid(True)
        ax.legend(loc='best')

    # Key rate subplot with linear scale
    ax_key = fig.add_subplot(gs[1, 2])
    ax_key.plot(fiber_lengths, key_rate_rel_error, 'r-', label='Relative Error Key Rate')
    ax_key.set_xlabel('Fiber Length (km)')
    ax_key.set_ylabel('Relative Error')
    ax_key.set_title('Relative Error for Key Rate')
    ax_key.set_ylim(-2, 2)  # Set a reasonable linear range
    ax_key.grid(True)
    ax_key.legend(loc='best')

    # Overall title
    exponent = int(np.log10(nx)) if nx is not None else ''
    fig.suptitle(f'Relative Errors of Parameters for $n_X = 10^{{{exponent}}}$', fontsize=16)

    # Save and close
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Relative error plot saved to {filename}")

    # Print plotting time
    plot_time = time.time() - plot_start_time
    print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [22]:
# def plot_relative_errors(fiber_lengths, optimized_key_rates, optimized_params_array, 
#                         predicted_key_rates, predicted_params_array, epoch, filename, 
#                         nx=None, threshold=1e-8):
#     """
#     Plots relative errors of predicted vs. optimized parameters and key rate for a specific n_X.
#     """
#     plot_start_time = time.time()

#     # Filter out NaN values
#     valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
#     fiber_lengths = fiber_lengths[valid_mask]
#     optimized_key_rates = optimized_key_rates[valid_mask]
#     predicted_key_rates = predicted_key_rates[valid_mask]
#     optimized_params_array = optimized_params_array[valid_mask]
#     predicted_params_array = predicted_params_array[valid_mask]

#     # Find cutoff where optimized key rate becomes very small
#     cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
#     if len(cutoff_idx) > 0:
#         cutoff_idx = cutoff_idx[0]
#     else:
#         cutoff_idx = len(fiber_lengths)

#     fiber_lengths = fiber_lengths[:cutoff_idx]
#     optimized_key_rates = optimized_key_rates[:cutoff_idx]
#     predicted_key_rates = predicted_key_rates[:cutoff_idx]
#     optimized_params_array = optimized_params_array[:cutoff_idx]
#     predicted_params_array = predicted_params_array[:cutoff_idx]

#     # Compute relative errors for parameters
#     relative_errors = []
#     param_labels = ['$\mu_1$', '$\mu_2$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$P_X$']
#     for i in range(5):
#         optimized = optimized_params_array[:, i]
#         predicted = predicted_params_array[:, i]
#         denominator = np.maximum(optimized, 1e-10)
#         rel_error = (predicted - optimized) / denominator
#         relative_errors.append(rel_error)
    
#     # Key rate relative error
#     denominator = np.maximum(optimized_key_rates, threshold)
#     key_rate_rel_error = (predicted_key_rates - optimized_key_rates) / denominator
#     # Remove capping to see the true relative err
# or
#     # key_rate_rel_error = np.clip(key_rate_rel_error, -10, 10)

#     # Create figure with 6 subplots (2 rows, 3 columns)
#     fig = plt.figure(figsize=(18, 10))
#     gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

#     # Plot relative errors for parameters
#     for i in range(5):
#         row = i // 3
#         col = i % 3
#         ax = fig.add_subplot(gs[row, col])
#         ax.plot(fiber_lengths, relative_errors[i], 'b-', label=f'Relative Error {param_labels[i]}')
#         ax.set_xlabel('Fiber Length (km)')
#         ax.set_ylabel('Relative Error')
#         ax.set_title(f'Relative Error for {param_labels[i]}')
#         ax.grid(True)
#         ax.legend(loc='best')

#     # Key rate subplot with linear scale
#     ax_key = fig.add_subplot(gs[1, 2])
#     ax_key.plot(fiber_lengths, key_rate_rel_error, 'r-', label='Relative Error Key Rate')
#     ax_key.set_xlabel('Fiber Length (km)')
#     ax_key.set_ylabel('Relative Error')
#     ax_key.set_title('Relative Error for Key Rate')
#     ax_key.set_ylim(-2, 2)  # Set a reasonable linear range
#     ax_key.grid(True)
#     ax_key.legend(loc='best')

#     # Overall title
#     exponent = int(np.log10(nx)) if nx is not None else ''
#     fig.suptitle(f'Relative Errors of Parameters for $n_X = 10^{{{exponent}}}$', fontsize=16)

#     # Save and close
#     plt.tight_layout(rect=[0, 0, 1, 0.95])
#     plt.savefig(filename, dpi=300, bbox_inches='tight')
#     plt.close(fig)
#     print(f"Relative error plot saved to {filename}")

#     # Print plotting time
#     plot_time = time.time() - plot_start_time
#     print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [23]:
# def plot_relative_errors(fiber_lengths, optimized_key_rates, optimized_params_array, 
#                         predicted_key_rates, predicted_params_array, epoch, filename, 
#                         nx=None, threshold=1e-8):
#     """
#     Plots relative errors of predicted vs. optimized parameters and key rate for a specific n_X.
    
#     Parameters:
#     - fiber_lengths: Array of fiber lengths (km).
#     - optimized_key_rates: Array of optimized key rates.
#     - optimized_params_array: Array of optimized parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
#     - predicted_key_rates: Array of predicted key rates.
#     - predicted_params_array: Array of predicted parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
#     - epoch: Epoch number or string (e.g., 'final_test').
#     - filename: Output filename for the plot.
#     - nx: The n_X value for the plot title (optional).
#     - threshold: Key rate threshold for physical cutoff (default: 1e-8).
#     """
#     # Start timing
#     plot_start_time = time.time()

#     # Filter out NaN values
#     valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
#     fiber_lengths = fiber_lengths[valid_mask]
#     optimized_key_rates = optimized_key_rates[valid_mask]
#     predicted_key_rates = predicted_key_rates[valid_mask]
#     optimized_params_array = optimized_params_array[valid_mask]
#     predicted_params_array = predicted_params_array[valid_mask]

#     # Find cutoff where optimized key rate becomes very small
#     cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
#     if len(cutoff_idx) > 0:
#         cutoff_idx = cutoff_idx[0]
#     else:
#         cutoff_idx = len(fiber_lengths)

#     fiber_lengths = fiber_lengths[:cutoff_idx]
#     optimized_key_rates = optimized_key_rates[:cutoff_idx]
#     predicted_key_rates = predicted_key_rates[:cutoff_idx]
#     optimized_params_array = optimized_params_array[:cutoff_idx]
#     predicted_params_array = predicted_params_array[:cutoff_idx]

#     # Compute relative errors for parameters
#     relative_errors = []
#     param_labels = ['$\mu_1$', '$\mu_2$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$P_X$']
#     for i in range(5):
#         optimized = optimized_params_array[:, i]
#         predicted = predicted_params_array[:, i]
#         # Avoid division by zero with a small threshold
#         denominator = np.maximum(optimized, 1e-10)
#         rel_error = (predicted - optimized) / denominator
#         relative_errors.append(rel_error)
    
#     # Key rate relative error
#     # Key rate relative error
#     denominator = np.maximum(optimized_key_rates, threshold)
#     key_rate_rel_error = (predicted_key_rates - optimized_key_rates) / denominator
# # Remove the capping to see the true relative error
# # key_rate_rel_error = np.clip(key_rate_rel_error, -10, 10)
#     # Cap the relative error for key rate to prevent extreme spikes
#     key_rate_rel_error = np.clip(key_rate_rel_error, -10, 10)

#     # Create figure with 6 subplots (2 rows, 3 columns)
#     fig = plt.figure(figsize=(18, 10))
#     gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

#     # Plot relative errors for parameters
#     for i in range(5):
#         row = i // 3
#         col = i % 3
#         ax = fig.add_subplot(gs[row, col])
#         ax.plot(fiber_lengths, relative_errors[i], 'b-', label=f'Relative Error {param_labels[i]}')
#         ax.set_xlabel('Fiber Length (km)')
#         ax.set_ylabel('Relative Error')
#         ax.set_title(f'Relative Error for {param_labels[i]}')
#         ax.grid(True)
#         ax.legend(loc='best')

#     # Key rate subplot with logarithmic scale
#     ax_key = fig.add_subplot(gs[1, 2])
#     ax_key.plot(fiber_lengths, key_rate_rel_error, 'r-', label='Relative Error Key Rate')
#     ax_key.set_xlabel('Fiber Length (km)')
#     ax_key.set_ylabel('Relative Error')
#     ax_key.set_title('Relative Error for Key Rate')
#     ax_key.set_yscale('symlog', linthresh=0.1)  # Symmetric logarithmic scale
#     ax_key.grid(True)
#     ax_key.legend(loc='best')

#     # Overall title
#     exponent = int(np.log10(nx)) if nx is not None else ''
#     fig.suptitle(f'Relative Errors of Parameters for $n_X = 10^{{{exponent}}}$', fontsize=16)

#     # Save and close
#     plt.tight_layout(rect=[0, 0, 1, 0.95])
#     plt.savefig(filename, dpi=300, bbox_inches='tight')
#     plt.close(fig)
#     print(f"Relative error plot saved to {filename}")

#     # Print plotting time
#     plot_time = time.time() - plot_start_time
#     print(f"--- Plotting Time: {plot_time:.2f} seconds ---")

In [24]:
import matplotlib.pyplot as plt
import numpy as np

def plot_keyrate_subplots_with_relative_error(all_data, epoch, filename):
    """
    Plots key rates and relative errors vs fiber length for multiple n_X values on separate subplots.
    
    Parameters:
    - all_data: Dictionary with keys as n_X values and values containing fiber_lengths, optimized_key_rates, predicted_key_rates.
    - epoch: Epoch number or string (e.g., 'final_test').
    - filename: Output filename for the plot.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True)
    fig.suptitle(f"Key Rates and Relative Error vs Fiber Length at Epoch {epoch+1} ($n_X = 10^s, s=4,5,...,9$)", fontsize=16)

    axes = axes.flatten()
    nx_values = sorted(all_data.keys())

    for i, nx in enumerate(nx_values):
        if i >= 6:
            break

        ax = axes[i]
        fiber_lengths = all_data[nx]['fiber_lengths']
        optimized_key_rates = all_data[nx]['optimized_key_rates']
        predicted_key_rates = all_data[nx].get('predicted_key_rates', None)

        # Filter out NaN values
        valid_mask = ~(np.isnan(optimized_key_rates))
        if predicted_key_rates is not None:
            valid_mask &= ~np.isnan(predicted_key_rates)
        fiber_lengths = fiber_lengths[valid_mask]
        optimized_key_rates = optimized_key_rates[valid_mask]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[valid_mask]

        # Apply key rate cutoff
        threshold = 1e-8
        cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
        if len(cutoff_idx) > 0:
            cutoff_idx = cutoff_idx[0]
        else:
            cutoff_idx = len(fiber_lengths)

        fiber_lengths = fiber_lengths[:cutoff_idx]
        optimized_key_rates = optimized_key_rates[:cutoff_idx]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[:cutoff_idx]

        # Plot key rates
        ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linewidth=2.0)
        if predicted_key_rates is not None:
            ax.plot(fiber_lengths, np.log10(predicted_key_rates), 'r:', label='Predicted Key Rate', linewidth=4.0)

        # Plot relative error on a secondary axis
        ax2 = ax.twinx()
# denominator = np.maximum(optimized_key_rates, 1e-8)
        denominator = np.maximum(optimized_key_rates, threshold)
        rel_error = (predicted_key_rates - optimized_key_rates) / denominator
        rel_error = np.clip(rel_error, -10, 10)
        ax2.plot(fiber_lengths, rel_error, 'g--', label='Relative Error', alpha=0.5)
        ax2.set_ylabel('Relative Error', color='g')
        ax2.tick_params(axis='y', labelcolor='g')
        ax2.set_ylim(-10, 10)

        # Set title and labels
        exponent = int(np.log10(nx))
        ax.set_title(f"$n_X = 10^{{{exponent}}}$")
        ax.grid(True)

        if i >= 3:
            ax.set_xlabel("Fiber Length (km)")
        if i % 3 == 0:
            ax.set_ylabel("Log10(Secret Key Rate per Pulse)")
        ax.legend(loc='upper left')
        ax2.legend(loc='upper right')

    # Remove empty subplots
    for i in range(len(nx_values), 6):
        fig.delaxes(axes[i])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"Key rate subplots with relative error saved to {filename}")

In [25]:
# Function to save predicted data as JSON
def save_dataset_to_json(plot_data, epoch, output_dir='datasets_json'):
    """
    Saves the predicted key rates and parameters for each n_X to JSON files.
    
    Parameters:
    - plot_data: Dictionary with n_X as keys and data (fiber_lengths, e_values, predicted key rates, predicted params) as values.
    - epoch: The epoch number.
    - output_dir: Directory to save the JSON files.
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for nx, data in plot_data.items():
        fiber_lengths = data['fiber_lengths']
        e_values = data['e_values']  # List of [e_1, e_2, e_3, e_4] for each fiber length
        predicted_key_rates = data['predicted_key_rates']
        predicted_params_array = data['predicted_params_array']

        # Create a dictionary in the desired JSON structure
        dataset = {}
        nx_key = str(float(nx))  # e.g., "10000.0"
        dataset[nx_key] = []

        for i in range(len(fiber_lengths)):
            entry = {
                "fiber_length": float(fiber_lengths[i]),
                "e_1": float(e_values[i][0]),
                "e_2": float(e_values[i][1]),
                "e_3": float(e_values[i][2]),
                "e_4": float(e_values[i][3]),
                "predicted_key_rate": float(predicted_key_rates[i]),  # Predicted key rate only
                "predicted_params": {
                    "mu_1": float(predicted_params_array[i, 0]),
                    "mu_2": float(predicted_params_array[i, 1]),
                    "P_mu_1": float(predicted_params_array[i, 2]),
                    "P_mu_2": float(predicted_params_array[i, 3]),
                    "P_X_value": float(predicted_params_array[i, 4])
                }
            }
            dataset[nx_key].append(entry)

        # Save to JSON
        filename = os.path.join(output_dir, f'predicted_keyrate_params_nx_{nx:.0e}_epoch_{epoch}.json')
        try:
            with open(filename, 'w') as f:
                json.dump(dataset, f, indent=4)
            print(f"Saved predicted dataset to {filename}")
        except Exception as e:
            print(f"Error saving predicted dataset to {filename}: {e}")

In [26]:
# import time

# # Start total training time measurement
# total_start_time = time.time()
# # Training loop
# num_epochs = 5000
# for epoch in range(num_epochs):
#     start_time = time.time()

#     # Training phase
#     model.train()
#     running_loss = 0.0
#     for inputs, targets in train_loader:
#         inputs, targets = inputs.to(device), targets.to(device)
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item() * inputs.size(0)
#     epoch_loss = running_loss / len(train_loader.dataset)
#     train_losses.append(epoch_loss)

#     # Validation phase
#     model.eval()
#     val_running_loss = 0.0
#     with torch.no_grad():
#         for inputs, targets in val_loader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, targets)
#             val_running_loss += loss.item() * inputs.size(0)
#     val_loss = val_running_loss / len(val_loader.dataset)
#     val_losses.append(val_loss)

#     # Update learning rate
#     scheduler.step(val_loss)
#     current_lr = optimizer.param_groups[0]['lr']
#     learning_rates.append(current_lr)

#     # Evaluate and plot at the first and last epochs
#     eval_time = 0
#     plot_time = 0
#     if epoch == 0 or epoch == num_epochs - 1:
#         model.eval()
#         with torch.no_grad():
#             # Prepare data for all n_X values
#             plot_data = {}
#             eval_start = time.time()
#             for nx, eval_data in all_evaluation_data.items():
#                 fiber_lengths = eval_data['fiber_lengths']
#                 optimized_params_array = eval_data['optimized_params_array']
#                 optimized_key_rates = eval_data['optimized_key_rates']
#                 X_test_tensor = eval_data['X_test_tensor']

#                 # Compute predicted parameters and key rates
#                 predicted_params_scaled = model(X_test_tensor).cpu().numpy()
#                 predicted_params_scaled = np.clip(predicted_params_scaled, 0, 1)
#                 predicted_params = y_scaler.inverse_transform(predicted_params_scaled)
#                 predicted_params[:, 0] = np.maximum(predicted_params[:, 0], 1e-6)
#                 predicted_params[:, 1] = np.maximum(predicted_params[:, 1], 1e-6)
#                 predicted_params[:, 2] = np.clip(predicted_params[:, 2], 0, 1)
#                 predicted_params[:, 3] = np.clip(predicted_params[:, 3], 0, 1)
#                 predicted_params[:, 4] = np.clip(predicted_params[:, 4], 0, 1)
#                 predicted_key_rates = []
#                 for params, L in zip(predicted_params, fiber_lengths):
#                     key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
#                     predicted_key_rates.append(key_rate)
#                 predicted_key_rates = np.array(predicted_key_rates)

#                 # Store data for plotting
#                 plot_data[nx] = {
#                     'fiber_lengths': fiber_lengths,
#                     'optimized_key_rates': optimized_key_rates,
#                     'predicted_key_rates': predicted_key_rates,
#                     'optimized_params_array': optimized_params_array,
#                     'predicted_params_array': predicted_params
#                 }

#                 # Plot relative errors for this n_X
#                 plot_relative_errors(
#                     fiber_lengths, optimized_key_rates, optimized_params_array,
#                     predicted_key_rates, predicted_params, epoch,
#                     f'relative_error_nx_{nx:.0e}_epoch_{epoch}.png', nx=nx
#                 )
#                 # Plot absolute differences for this n_X
#                 plot_absolute_differences_keyrate_and_parameters(
#                     fiber_lengths, optimized_key_rates, optimized_params_array,
#                     predicted_key_rates, predicted_params, epoch,
#                     filename=f'absolute_differences_nx_{nx:.0e}.png', nx=nx
#                 )

#             eval_time += time.time() - eval_start

#             # Plot subplots at the first and last epochs
#             plot_start = time.time()
#             if epoch == 0:
#                 plot_keyrate_subplots(plot_data, epoch, f'keyrate_subplots_first_epoch.png')
#                 plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_first_epoch.png')
#             if epoch == num_epochs - 1:
#                 plot_keyrate_subplots(plot_data, epoch, f'keyrate_subplots_last_epoch.png')
#                 plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_last_epoch.png')
#                 # Optionally, plot key rate subplots with relative error
#                 plot_keyrate_subplots_with_relative_error(plot_data, epoch, f'keyrate_subplots_with_rel_error_last_epoch.png')
#             plot_time += time.time() - plot_start

#             torch.save(model.state_dict(), 'bb84_nn_model.pth')
#             print("Model saved to bb84_nn_model.pth")

#     # Print epoch results with timing and learning rate
#     epoch_time = time.time() - start_time
#     print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, "
#           f"Learning Rate: {current_lr:.6f}, Time: {epoch_time:.2f}s (Eval: {eval_time:.2f}s, Plot: {plot_time:.2f}s)")

# # Print total training time
# total_training_time = time.time() - total_start_time
# print(f"--- Total Training Time: {total_training_time:.2f} seconds ---")

In [27]:
# Start total training time measurement
total_start_time = time.time()
# Training loop
num_epochs = 5000
for epoch in range(num_epochs):
    start_time = time.time()

    # Training phase
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)

    # Validation phase
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_running_loss += loss.item() * inputs.size(0)
    val_loss = val_running_loss / len(val_loader.dataset)
    val_losses.append(val_loss)

    # Update learning rate
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

    # Evaluate and plot at the first and last epochs
    eval_time = 0
    plot_time = 0
    if epoch == 0 or epoch == num_epochs - 1:
        model.eval()
        with torch.no_grad():
            # Prepare data for all n_X values
            plot_data = {}
            eval_start = time.time()
            for nx, eval_data in all_evaluation_data.items():
                fiber_lengths = eval_data['fiber_lengths']
                optimized_params_array = eval_data['optimized_params_array']
                optimized_key_rates = eval_data['optimized_key_rates']
                X_test_tensor = eval_data['X_test_tensor']

                # Compute e_1, e_2, e_3, e_4 for each fiber length
                e_values = []
                for L in fiber_lengths:
                    e_1 = L / 100
                    e_2 = -np.log10(6e-7)
                    e_3 = 5e-3 * 100
                    e_4 = np.log10(nx)
                    e_values.append([e_1, e_2, e_3, e_4])

                # Compute predicted parameters and key rates
                predicted_params_scaled = model(X_test_tensor).cpu().numpy()
                predicted_params_scaled = np.clip(predicted_params_scaled, 0, 1)
                predicted_params = y_scaler.inverse_transform(predicted_params_scaled)
                predicted_params[:, 0] = np.maximum(predicted_params[:, 0], 1e-6)
                predicted_params[:, 1] = np.maximum(predicted_params[:, 1], 1e-6)
                predicted_params[:, 2] = np.clip(predicted_params[:, 2], 0, 1)
                predicted_params[:, 3] = np.clip(predicted_params[:, 3], 0, 1)
                predicted_params[:, 4] = np.clip(predicted_params[:, 4], 0, 1)
                predicted_key_rates = []
                for params, L in zip(predicted_params, fiber_lengths):
                    key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
                    predicted_key_rates.append(key_rate)
                predicted_key_rates = np.array(predicted_key_rates)

                # Store data for plotting and saving
                plot_data[nx] = {
                    'fiber_lengths': fiber_lengths,
                    'e_values': e_values,  # Store e_1, e_2, e_3, e_4
                    'optimized_key_rates': optimized_key_rates,
                    'predicted_key_rates': predicted_key_rates,
                    'optimized_params_array': optimized_params_array,
                    'predicted_params_array': predicted_params
                }

                # Plot relative errors for this n_X
                plot_relative_errors(
                    fiber_lengths, optimized_key_rates, optimized_params_array,
                    predicted_key_rates, predicted_params, epoch,
                    f'relative_error_nx_{nx:.0e}_epoch_{epoch}.png', nx=nx
                )
                # Plot absolute differences for this n_X
                plot_absolute_differences_keyrate_and_parameters(
                    fiber_lengths, optimized_key_rates, optimized_params_array,
                    predicted_key_rates, predicted_params, epoch,
                    filename=f'absolute_differences_nx_{nx:.0e}.png', nx=nx
                )

            eval_time += time.time() - eval_start

            # Save the predicted datasets in JSON format
            save_dataset_to_json(plot_data, epoch)

            # Plot subplots at the first and last epochs
            plot_start = time.time()
            if epoch == 0:
                plot_keyrate_subplots(plot_data, epoch, f'keyrate_subplots_first_epoch.png')
                plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_first_epoch.png')
            if epoch == num_epochs - 1:
                plot_keyrate_subplots(plot_data, epoch, f'keyrate_subplots_last_epoch.png')
                plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_last_epoch.png')
                # Optionally, plot key rate subplots with relative error
                plot_keyrate_subplots_with_relative_error(plot_data, epoch, f'keyrate_subplots_with_rel_error_last_epoch.png')
            plot_time += time.time() - plot_start

            torch.save(model.state_dict(), 'bb84_nn_model.pth')
            print("Model saved to bb84_nn_model.pth")

    # Print epoch results with timing and learning rate
    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, "
          f"Learning Rate: {current_lr:.6f}, Time: {epoch_time:.2f}s (Eval: {eval_time:.2f}s, Plot: {plot_time:.2f}s)")

# Print total training time
total_training_time = time.time() - total_start_time
print(f"--- Total Training Time: {total_training_time:.2f} seconds ---")

objective returned NaN/Inf for params [0.44669247 0.2409524  0.156103   0.41808417 0.71528256], L=0.0, nx=10000: nan
objective returned NaN/Inf for params [0.44669127 0.24089167 0.15610468 0.4178427  0.71488273], L=0.2002002002002002, nx=10000: nan
objective returned NaN/Inf for params [0.44669005 0.24083094 0.15610638 0.4176012  0.714483  ], L=0.4004004004004004, nx=10000: nan
objective returned NaN/Inf for params [0.4466888  0.24077024 0.15610807 0.41735974 0.71408325], L=0.6006006006006006, nx=10000: nan
objective returned NaN/Inf for params [0.44668794 0.24070756 0.1561078  0.41711578 0.71367866], L=0.8008008008008008, nx=10000: nan
objective returned NaN/Inf for params [0.4466871  0.24064463 0.15610725 0.41687143 0.71327364], L=1.001001001001001, nx=10000: nan
objective returned NaN/Inf for params [0.44668627 0.24058169 0.15610671 0.4166271  0.7128684 ], L=1.2012012012012012, nx=10000: nan
objective returned NaN/Inf for params [0.4466854  0.24051875 0.15610614 0.4163828  0.7124632

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+04_epoch_0.png
--- Plotting Time: 1.23 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_nx_1e+04.png
--- Plotting Time: 1.09 seconds ---
objective returned negative key rate -1e+250 for params [0.44349077 0.24490276 0.17014915 0.42767176 0.737245  ], L=0.0, nx=100000
objective returned negative key rate -1e+250 for params [0.4434859  0.2448131  0.17015651 0.4273491  0.7367059 ], L=0.2002002002002002, nx=100000
objective returned negative key rate -1e+250 for params [0.44348106 0.24472348 0.17016387 0.42702642 0.7361669 ], L=0.4004004004004004, nx=100000
objective returned negative key rate -1e+250 for params [0.4434762  0.24463384 0.17017122 0.42670375 0.7356277 ], L=0.6006006006006006, nx=100000
objective returned negative key rate -1e+250 for params [0.44347137 0.24454422 0.17017858 0.42638114 0.7350887 ], L=0.8008008008008008, nx=100000
objective returned negative key rate -1e+250 for params [0.44346648 0.24445459 0.17018594 0.42605847 0.73454964], L=1.001001001001001, nx=100000
objective returned negative key rat

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_nx_1e+05.png
--- Plotting Time: 1.09 seconds ---
objective returned negative key rate -1e+250 for params [0.44076607 0.25530553 0.18275614 0.4596907  0.79228467], L=0.0, nx=1000000
objective returned negative key rate -1e+250 for params [0.44074416 0.25511822 0.18278117 0.45905462 0.79120404], L=0.4004004004004004, nx=1000000
objective returned negative key rate -1e+250 for params [0.44073322 0.25502455 0.18279369 0.45873648 0.7906636 ], L=0.6006006006006006, nx=1000000
objective returned negative key rate -1e+250 for params [0.44072226 0.2549309  0.18280621 0.45841846 0.7901233 ], L=0.8008008008008008, nx=1000000
objective returned negative key rate -1e+250 for params [0.44071132 0.25483724 0.18281873 0.45810038 0.7895829 ], L=1.001001001001001, nx=1000000
objective returned negative key rate -1e+250 for params [0.44068938 0.25464994 0.18284377 0.45746434 0.7885023 ], L=1.4014014014014016, nx=1000000
objective returned negative k

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()


Absolute differences plot saved to absolute_differences_nx_1e+06.png
--- Plotting Time: 1.08 seconds ---
objective returned negative key rate -1e+250 for params [0.438245   0.27097562 0.19017224 0.4943575  0.86906093], L=0.0, nx=10000000
objective returned negative key rate -1e+250 for params [0.43823132 0.27087304 0.19019267 0.49401915 0.86847425], L=0.2002002002002002, nx=10000000
objective returned negative key rate -1e+250 for params [0.4382177  0.27077046 0.1902131  0.49368078 0.86788756], L=0.4004004004004004, nx=10000000
objective returned negative key rate -1e+250 for params [0.43819037 0.27056527 0.19025393 0.49300402 0.866714  ], L=0.8008008008008008, nx=10000000
objective returned negative key rate -1e+250 for params [0.4381494  0.2702575  0.1903152  0.49198902 0.8649539 ], L=1.4014014014014016, nx=10000000
objective returned negative key rate -1e+250 for params [0.4381357  0.27015492 0.19033562 0.49165067 0.86436725], L=1.6016016016016017, nx=10000000
objective returned neg

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+07_epoch_0.png
--- Plotting Time: 1.20 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_nx_1e+07.png
--- Plotting Time: 1.09 seconds ---
objective returned negative key rate -1e+250 for params [0.43608138 0.28954342 0.19650832 0.53647006 0.8765765 ], L=0.0, nx=100000000
objective returned negative key rate -1e+250 for params [0.436068   0.28949133 0.19652261 0.53625286 0.8765765 ], L=0.2002002002002002, nx=100000000
objective returned negative key rate -1e+250 for params [0.43605462 0.28943923 0.19653687 0.5360355  0.8765765 ], L=0.4004004004004004, nx=100000000
objective returned negative key rate -1e+250 for params [0.43604124 0.28938717 0.19655119 0.5358182  0.8765765 ], L=0.6006006006006006, nx=100000000
objective returned negative key rate -1e+250 for params [0.43602785 0.28933507 0.19656548 0.53560084 0.8765765 ], L=0.8008008008008008, nx=100000000
objective returned negative key rate -1e+250 for params [0.43601444 0.289283   0.19657977 0.5353835  0.8765765 ], L=1.001001001001001, nx=100000000
objective returne

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_nx_1e+08.png
--- Plotting Time: 1.09 seconds ---
objective returned negative key rate -1e+250 for params [0.43471947 0.30924883 0.20689605 0.5803682  0.8765765 ], L=0.0, nx=1000000000
objective returned negative key rate -1e+250 for params [0.43469313 0.30914453 0.20691626 0.57993793 0.8765765 ], L=0.4004004004004004, nx=1000000000
objective returned negative key rate -1e+250 for params [0.43466675 0.3090402  0.2069365  0.57950765 0.8765765 ], L=0.8008008008008008, nx=1000000000
objective returned negative key rate -1e+250 for params [0.4346536  0.30898803 0.20694663 0.57929254 0.8765765 ], L=1.001001001001001, nx=1000000000
objective returned negative key rate -1e+250 for params [0.43464044 0.30893588 0.20695673 0.57907736 0.8765765 ], L=1.2012012012012012, nx=1000000000
objective returned negative key rate -1e+250 for params [0.43462723 0.30888373 0.20696686 0.57886213 0.8765765 ], L=1.4014014014014016, nx=1000000000
objective r

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linestyle='-', linewidth=2.0, alpha=1.0)  # Thicker line, semi-transparent
  ax.plot(fiber_lengths, np.log10(predicted_key_rates), 'r', label='Predicted Key Rate',


Absolute differences plot saved to absolute_differences_nx_1e+09.png
--- Plotting Time: 1.10 seconds ---
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+04_epoch_0.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+05_epoch_0.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+06_epoch_0.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+07_epoch_0.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+08_epoch_0.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+09_epoch_0.json


  plt.show()


Key rate subplots saved to keyrate_subplots_first_epoch.png


  plt.show()


Parameter subplots saved to parameters_subplots_first_epoch.png
Model saved to bb84_nn_model.pth
Epoch 1/5000, Train Loss: 0.2632, Val Loss: 0.0837, Learning Rate: 0.001000, Time: 18.03s (Eval: 14.72s, Plot: 2.75s)
Epoch 2/5000, Train Loss: 0.0403, Val Loss: 0.0184, Learning Rate: 0.001000, Time: 0.25s (Eval: 0.00s, Plot: 0.00s)
Epoch 3/5000, Train Loss: 0.0111, Val Loss: 0.0072, Learning Rate: 0.001000, Time: 0.23s (Eval: 0.00s, Plot: 0.00s)
Epoch 4/5000, Train Loss: 0.0054, Val Loss: 0.0045, Learning Rate: 0.001000, Time: 0.23s (Eval: 0.00s, Plot: 0.00s)
Epoch 5/5000, Train Loss: 0.0035, Val Loss: 0.0032, Learning Rate: 0.001000, Time: 0.22s (Eval: 0.00s, Plot: 0.00s)
Epoch 6/5000, Train Loss: 0.0025, Val Loss: 0.0023, Learning Rate: 0.001000, Time: 0.23s (Eval: 0.00s, Plot: 0.00s)
Epoch 7/5000, Train Loss: 0.0020, Val Loss: 0.0019, Learning Rate: 0.001000, Time: 0.23s (Eval: 0.00s, Plot: 0.00s)
Epoch 8/5000, Train Loss: 0.0018, Val Loss: 0.0017, Learning Rate: 0.001000, Time: 0.22s 

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+04_epoch_4999.png
--- Plotting Time: 1.17 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()


Absolute differences plot saved to absolute_differences_nx_1e+04.png
--- Plotting Time: 1.15 seconds ---
objective returned negative key rate -1e+250 for params [0.58182997 0.21564907 0.09028985 0.69892424 0.5577087 ], L=0.0, nx=100000
objective returned negative key rate -1e+250 for params [0.5817103  0.21554945 0.09028215 0.6987757  0.5576416 ], L=0.2002002002002002, nx=100000
objective returned negative key rate -1e+250 for params [0.58147097 0.21535023 0.09026685 0.6984784  0.55750734], L=0.6006006006006006, nx=100000
objective returned negative key rate -1e+250 for params [0.58135134 0.21525061 0.09025914 0.69832987 0.5574404 ], L=0.8008008008008008, nx=100000
objective returned negative key rate -1e+250 for params [0.5812317  0.21515101 0.09025147 0.6981812  0.5573732 ], L=1.001001001001001, nx=100000
objective returned negative key rate -1e+250 for params [0.581112   0.21505141 0.09024385 0.69803256 0.5573061 ], L=1.2012012012012012, nx=100000
objective returned negative key rat

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+05_epoch_4999.png
--- Plotting Time: 1.28 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_nx_1e+05.png
--- Plotting Time: 1.13 seconds ---
objective returned negative key rate -1e+250 for params [0.5501555  0.2681113  0.13089064 0.74261904 0.66213644], L=0.0, nx=1000000
objective returned negative key rate -1e+250 for params [0.55004585 0.26806915 0.1308814  0.7425602  0.66214496], L=0.2002002002002002, nx=1000000
objective returned negative key rate -1e+250 for params [0.54982644 0.2679849  0.13086297 0.7424427  0.66216195], L=0.6006006006006006, nx=1000000
objective returned negative key rate -1e+250 for params [0.54971665 0.26794273 0.13085374 0.74238396 0.6621705 ], L=0.8008008008008008, nx=1000000
objective returned negative key rate -1e+250 for params [0.54949725 0.26785845 0.13083531 0.74226636 0.6621875 ], L=1.2012012012012012, nx=1000000
objective returned negative key rate -1e+250 for params [0.5493876  0.26781633 0.13082607 0.7422075  0.662196  ], L=1.4014014014014016, nx=1000000
objective returned negative 

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()


Absolute differences plot saved to absolute_differences_nx_1e+06.png
--- Plotting Time: 1.11 seconds ---
objective returned negative key rate -1e+250 for params [0.523607   0.30907556 0.14995316 0.7649051  0.7575522 ], L=0.0, nx=10000000
objective returned negative key rate -1e+250 for params [0.52351975 0.30902073 0.14995466 0.7648488  0.7575371 ], L=0.2002002002002002, nx=10000000
objective returned negative key rate -1e+250 for params [0.52343255 0.30896595 0.14995621 0.76479244 0.757522  ], L=0.4004004004004004, nx=10000000
objective returned negative key rate -1e+250 for params [0.52334535 0.30891117 0.14995776 0.76473606 0.75750697], L=0.6006006006006006, nx=10000000
objective returned negative key rate -1e+250 for params [0.5231709 0.3088015 0.1499608 0.7646233 0.7574767], L=1.001001001001001, nx=10000000
objective returned negative key rate -1e+250 for params [0.5230836  0.30874673 0.14996234 0.7645669  0.7574616 ], L=1.2012012012012012, nx=10000000
objective returned negative 

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+07_epoch_4999.png
--- Plotting Time: 1.25 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_nx_1e+07.png
--- Plotting Time: 1.26 seconds ---
objective returned negative key rate -1e+250 for params [0.5047476  0.33991244 0.16675976 0.77275324 0.8199466 ], L=0.0, nx=100000000
objective returned negative key rate -1e+250 for params [0.5046353  0.33986747 0.16677073 0.77275324 0.81997263], L=0.2002002002002002, nx=100000000
objective returned negative key rate -1e+250 for params [0.50452286 0.3398225  0.16678162 0.77275324 0.81999886], L=0.4004004004004004, nx=100000000
objective returned negative key rate -1e+250 for params [0.50441045 0.3397776  0.1667926  0.77275324 0.820025  ], L=0.6006006006006006, nx=100000000
objective returned negative key rate -1e+250 for params [0.5042981  0.3397326  0.16680352 0.77275324 0.8200512 ], L=0.8008008008008008, nx=100000000
objective returned negative key rate -1e+250 for params [0.50418574 0.3396876  0.16681443 0.77275324 0.8200773 ], L=1.001001001001001, nx=100000000
objective returne

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()


Absolute differences plot saved to absolute_differences_nx_1e+08.png
--- Plotting Time: 1.12 seconds ---
objective returned negative key rate -1e+250 for params [0.48881444 0.36304    0.18532406 0.77275324 0.87579834], L=0.0, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48870486 0.3630018  0.18540026 0.77275324 0.8757886 ], L=0.2002002002002002, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48859522 0.36296362 0.18547645 0.77275324 0.875779  ], L=0.4004004004004004, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48848563 0.3629254  0.18555264 0.77275324 0.87576926], L=0.6006006006006006, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48837605 0.36288726 0.18562882 0.77275324 0.8757596 ], L=0.8008008008008008, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48826644 0.36284903 0.18570499 0.77275324 0.87575   ], L=1.001001001001001, nx=1000000000
objective r

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+09_epoch_4999.png
--- Plotting Time: 1.26 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()
  ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linestyle='-', linewidth=2.0, alpha=1.0)  # Thicker line, semi-transparent
  ax.plot(fiber_lengths, np.log10(predicted_key_rates), 'r', label='Predicted Key Rate',


Absolute differences plot saved to absolute_differences_nx_1e+09.png
--- Plotting Time: 1.26 seconds ---
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+04_epoch_4999.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+05_epoch_4999.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+06_epoch_4999.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+07_epoch_4999.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+08_epoch_4999.json
Saved predicted dataset to datasets_json/predicted_keyrate_params_nx_1e+09_epoch_4999.json


  plt.show()


Key rate subplots saved to keyrate_subplots_last_epoch.png


  plt.show()
  ax.plot(fiber_lengths, np.log10(predicted_key_rates), 'r:', label='Predicted Key Rate', linewidth=4.0)


Parameter subplots saved to parameters_subplots_last_epoch.png
Key rate subplots with relative error saved to keyrate_subplots_with_rel_error_last_epoch.png
Model saved to bb84_nn_model.pth
Epoch 5000/5000, Train Loss: 0.0002, Val Loss: 0.0002, Learning Rate: 0.000001, Time: 19.91s (Eval: 15.26s, Plot: 4.34s)
--- Total Training Time: 1216.89 seconds ---


  plt.show()


In [28]:
# import time 

# # Start total training time measurement
# total_start_time = time.time()
# # Training loop
# num_epochs = 5000
# for epoch in range(num_epochs):
#     start_time = time.time()

#     # Training phase
#     model.train()
#     running_loss = 0.0
#     for inputs, targets in train_loader:
#         inputs, targets = inputs.to(device), targets.to(device)
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item() * inputs.size(0)
#     epoch_loss = running_loss / len(train_loader.dataset)
#     train_losses.append(epoch_loss)

#     # Validation phase
#     model.eval()
#     val_running_loss = 0.0
#     with torch.no_grad():
#         for inputs, targets in val_loader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, targets)
#             val_running_loss += loss.item() * inputs.size(0)
#     val_loss = val_running_loss / len(val_loader.dataset)
#     val_losses.append(val_loss)

#     # Update learning rate
#     scheduler.step(val_loss)
#     current_lr = optimizer.param_groups[0]['lr']
#     learning_rates.append(current_lr)

#     # Evaluate and plot at the first and last epochs
#     eval_time = 0
#     plot_time = 0
#     if epoch == 0 or epoch == num_epochs - 1:
#         model.eval()
#         with torch.no_grad():
#             # Prepare data for all n_X values
#             plot_data = {}
#             eval_start = time.time()
#             for nx, eval_data in all_evaluation_data.items():
#                 fiber_lengths = eval_data['fiber_lengths']
#                 optimized_params_array = eval_data['optimized_params_array']
#                 optimized_key_rates = eval_data['optimized_key_rates']
#                 X_test_tensor = eval_data['X_test_tensor']

#                 # Compute predicted parameters and key rates
#                 predicted_params_scaled = model(X_test_tensor).cpu().numpy()
#                 predicted_params_scaled = np.clip(predicted_params_scaled, 0, 1)
#                 predicted_params = y_scaler.inverse_transform(predicted_params_scaled)
#                 predicted_params[:, 0] = np.maximum(predicted_params[:, 0], 1e-6)
#                 predicted_params[:, 1] = np.maximum(predicted_params[:, 1], 1e-6)
#                 predicted_params[:, 2] = np.clip(predicted_params[:, 2], 0, 1)
#                 predicted_params[:, 3] = np.clip(predicted_params[:, 3], 0, 1)
#                 predicted_params[:, 4] = np.clip(predicted_params[:, 4], 0, 1)
#                 predicted_key_rates = []
#                 for params, L in zip(predicted_params, fiber_lengths):
#                     key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
#                     predicted_key_rates.append(key_rate)
#                 predicted_key_rates = np.array(predicted_key_rates)

#                 # Store data for plotting
#                 plot_data[nx] = {
#                     'fiber_lengths': fiber_lengths,
#                     'optimized_key_rates': optimized_key_rates,
#                     'predicted_key_rates': predicted_key_rates,
#                     'optimized_params_array': optimized_params_array,
#                     'predicted_params_array': predicted_params
#                 }
#             eval_time += time.time() - eval_start

#             # Plot subplots at the first and last epochs
#             plot_start = time.time()
#             if epoch == 0:
#                 plot_keyrate_subplots(plot_data, epoch, f'keyrate_subplots_first_epoch.png')
#                 plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_first_epoch.png')
#             if epoch == num_epochs - 1:
#                 plot_keyrate_subplots(plot_data, epoch, f'keyrate_subplots_last_epoch.png')
#                 plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_last_epoch.png')
#             plot_time += time.time() - plot_start

#             torch.save(model.state_dict(), 'bb84_nn_model.pth')
#             print("Model saved to bb84_nn_model.pth")

#             # Plot relative errors
#             plot_relative_errors(
#                 fiber_lengths, optimized_key_rates, optimized_params_array,
#                 predicted_key_rates, predicted_params, epoch,
#                 f'relative_error_nx_{nx:.0e}_epoch_{epoch}.png', nx=nx
#             )
#             # Add absolute differences plot for this n_X
#             plot_absolute_differences_keyrate_and_parameters(
#                 fiber_lengths, optimized_key_rates, optimized_params_array,
#                 predicted_key_rates, predicted_params, epoch,
#                 filename=f'absolute_differences_nx_{nx:.0e}.png', nx=nx
#             )
#     # Print epoch results with timing and learning rate
#     epoch_time = time.time() - start_time
#     print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, "
#           f"Learning Rate: {current_lr:.6f}, Time: {epoch_time:.2f}s (Eval: {eval_time:.2f}s, Plot: {plot_time:.2f}s)")

# # Print total training time
# total_training_time = time.time() - total_start_time
# print(f"--- Total Training Time: {total_training_time:.2f} seconds ---")

In [29]:
# %%
# Plotting the losses with smoothing
def moving_average(data, window_size=5):
    return np.convolve(data, np.ones(window_size)/window_size, mode='valid')

epochs = range(1, len(train_losses) + 1)
window_size = 5
smoothed_train_losses = moving_average(train_losses, window_size)
smoothed_val_losses = moving_average(val_losses, window_size)
smoothed_epochs = range(window_size, len(train_losses) + 1)

In [30]:
plt.figure(figsize=(12, 6))
# plt.plot(smoothed_epochs, smoothed_train_losses, label='Smoothed Training Loss', linestyle='-')
# plt.plot(smoothed_epochs, smoothed_val_losses, label='Smoothed Validation Loss', linestyle='--')
plt.plot(epochs, train_losses, label='Training Loss', linestyle='-', alpha=0.3)
plt.plot(epochs, val_losses, label='Validation Loss', linestyle='--', alpha=0.5)
plt.xlabel('Epochs')
plt.ylim(0, 0.002)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_plot.png', dpi=150)
plt.show()
# plt.close()

print("Training Complete")

# Plotting the learning rates
plt.figure(figsize=(12, 6))
plt.plot(epochs, learning_rates, label='Learning Rate', linestyle='-', color='tab:blue')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.legend()
plt.title('Learning Rate over Epochs')
plt.grid(True)
plt.tight_layout()
plt.savefig('learning_rate_plot.png', dpi=150)
plt.show()
# plt.close()

  plt.show()


Training Complete


  plt.show()


In [31]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming you have your training and validation losses, epochs, and learning rates
# For example, let's create some dummy data:
epochs = np.arange(1, 101)  # Example: 100 epochs
train_losses = np.random.rand(100) * 0.001 + np.linspace(0.0005, 0.0001, 100)  # Example decreasing trend with noise
val_losses = np.random.rand(100) * 0.0005 + np.linspace(0.0007, 0.0002, 100) # Example decreasing trend with noise
learning_rates = np.linspace(0.01, 0.0001, 100)  # Example learning rate decay



# 1. Plotting Loss
plt.figure(figsize=(12, 6))
plt.plot(epochs, train_losses, label='Training Loss', linestyle='-', alpha=0.3)
plt.plot(epochs, val_losses, label='Validation Loss', linestyle='--', alpha=0.5)
plt.xlabel('Epochs')
plt.ylim(0, 0.002)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_plot.png', dpi=150)
plt.show()


# 2. Finding Epochs Where Loss No Longer Decreases (or increases, more generally)
#    Focus on the validation loss, as that indicates generalization
val_loss_diff = np.diff(val_losses) # calculate the differences between consecutive validation loss values
# Find the indices where the difference is greater than a small tolerance (e.g., indicating no more significant decrease)

tolerance = 1e-6  # A small value to define when the loss stops decreasing significantly.  Adjust this.

increasing_epochs = np.where(val_loss_diff > tolerance)[0] + 1 # Add 1 to shift to epoch indices, not diff indices.

# Print the epochs where the validation loss starts to increase/stops decreasing
if len(increasing_epochs) > 0:
    print("Epochs where validation loss is no longer decreasing (or starts increasing):")
    for epoch in increasing_epochs:
        print(f"Epoch {epoch}: Loss = {val_losses[epoch-1]:.6f}") # Display the losses at these epochs.
else:
    print("Validation loss consistently decreases (or does not increase significantly) throughout the training.")

# 3. Plotting Learning Rate
plt.figure(figsize=(12, 6))
plt.plot(epochs, learning_rates, label='Learning Rate', linestyle='-', color='tab:blue')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.legend()
plt.title('Learning Rate over Epochs')
plt.grid(True)
plt.tight_layout()
plt.savefig('learning_rate_plot.png', dpi=150)
plt.show()

print("Training Complete")

  plt.show()


Epochs where validation loss is no longer decreasing (or starts increasing):
Epoch 1: Loss = 0.001059
Epoch 3: Loss = 0.000698
Epoch 6: Loss = 0.000811
Epoch 8: Loss = 0.000665
Epoch 10: Loss = 0.000668
Epoch 12: Loss = 0.000767
Epoch 14: Loss = 0.000874
Epoch 16: Loss = 0.000742
Epoch 18: Loss = 0.000677
Epoch 21: Loss = 0.000615
Epoch 22: Loss = 0.000736
Epoch 24: Loss = 0.000611
Epoch 27: Loss = 0.000635
Epoch 28: Loss = 0.000864
Epoch 31: Loss = 0.000555
Epoch 33: Loss = 0.000805
Epoch 36: Loss = 0.000527
Epoch 38: Loss = 0.000892
Epoch 41: Loss = 0.000842
Epoch 43: Loss = 0.000911
Epoch 47: Loss = 0.000508
Epoch 48: Loss = 0.000868
Epoch 50: Loss = 0.000686
Epoch 52: Loss = 0.000643
Epoch 55: Loss = 0.000501
Epoch 57: Loss = 0.000461
Epoch 58: Loss = 0.000731
Epoch 60: Loss = 0.000526
Epoch 62: Loss = 0.000649
Epoch 64: Loss = 0.000458
Epoch 65: Loss = 0.000489
Epoch 66: Loss = 0.000688
Epoch 68: Loss = 0.000735
Epoch 71: Loss = 0.000490
Epoch 74: Loss = 0.000356
Epoch 75: Loss = 

  plt.show()


In [32]:
import matplotlib.pyplot as plt
import numpy as np # Often needed for numerical operations like min/max

# Assuming epochs, train_losses, val_losses, learning_rates are defined from your training process
# Example dummy data - REPLACE WITH YOUR ACTUAL DATA
epochs = list(range(1, 5001))
# Simulate training loss decreasing
train_losses = [1.0 / (e + 50) + np.random.normal(0, 0.00005) for e in epochs]
# Simulate validation loss decreasing then plateauing/increasing slightly
val_losses = [1.0 / (e + 30) + np.random.normal(0, 0.00008) for e in epochs[:3500]] + \
             [1.0 / (3500 + 30) + np.random.normal(0, 0.0001) + (e - 3500)*1e-7 for e in epochs[3500:]] # Simulate slight increase after 3500

# Ensure losses are positive (common requirement) and scale example data
train_losses = np.maximum(0.0001, np.array(train_losses)).tolist()
val_losses = np.maximum(0.0002, np.array(val_losses)).tolist()

# Simulate learning rate (example: decay)
learning_rates = [0.001 * 0.95**(e/100) for e in epochs]
# --- End of dummy data ---


plt.figure(figsize=(12, 6))
plt.plot(epochs, train_losses, label='Training Loss', linestyle='-', alpha=0.3)
plt.plot(epochs, val_losses, label='Validation Loss', linestyle='--', alpha=0.5)
plt.xlabel('Epochs')
# Use max of data or user's limit for ylim
plt.ylim(0, max(max(train_losses), max(val_losses)) * 1.1 or 0.002)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_plot.png', dpi=150)
plt.show()
# plt.close() # Keep commented if you want the plot to display immediately

print("Training Complete")

# --- New Code to Detect Plateau ---

# Define patience: how many epochs to wait for an improvement before declaring a plateau
patience = 50 # Adjust this value based on how much fluctuation you expect

# --- For Validation Loss ---
print("\nAnalyzing Validation Loss:")
best_val_loss = float('inf')
epochs_no_improve_val = 0
plateau_start_epoch_val = None

# Iterate through validation losses starting from 'patience' epochs in
# to allow for initial fluctuations
for i in range(patience, len(val_losses)):
    current_val_loss = val_losses[i]
    previous_val_loss = val_losses[i-1] # Compare to previous for general trend

    # Check if current loss is NOT significantly better than the best seen so far
    # Use a small epsilon to handle floating point noise
    # We are looking for when it stops decreasing or starts increasing
    # A simple check: is the loss NOT decreasing compared to the best seen?
    if current_val_loss >= best_val_loss - 1e-6: # Use >= to catch plateaus and increases
         epochs_no_improve_val += 1
    else: # Significant improvement
         best_val_loss = current_val_loss
         epochs_no_improve_val = 0
         plateau_start_epoch_val = None # Reset if we see improvement

    # If no significant improvement for 'patience' epochs and we haven't marked a plateau yet
    if epochs_no_improve_val >= patience and plateau_start_epoch_val is None:
         # The plateau started *after* the best loss was seen and improvement stopped
         # The epoch where the 'patience' period of no improvement *started* is the detection point.
         # This is the epoch 'patience' steps BEFORE the current epoch 'i'.
         plateau_start_epoch_val = epochs[i - patience]
         print(f"Validation loss showed no significant improvement for {patience} consecutive epochs starting around Epoch {plateau_start_epoch_val}.")
         # You can break here if you only want the first instance of a plateau
         # break


if plateau_start_epoch_val is None and epochs_no_improve_val > 0:
    # This handles the case where training finishes during a period of no improvement,
    # but the patience limit wasn't fully reached *before* the end.
     print(f"Validation loss showed no significant improvement over the last {epochs_no_improve_val} epochs.")
elif plateau_start_epoch_val is None:
     print("Validation loss was consistently improving or stable throughout training (no plateau of size >= patience detected).")


# --- For Training Loss ---
# Training loss should ideally always decrease slowly. Detecting a plateau
# often means the learning rate is too low or it's stuck in a shallow minimum.
# A simple way is to check the relative change over a window.
print("\nAnalyzing Training Loss:")
window_size = 100 # Look back over this many epochs

plateau_start_epoch_train = None
relative_loss_change_threshold = 1e-4 # Percentage change threshold (0.01%)

for i in range(window_size, len(train_losses)):
    loss_window = train_losses[i - window_size : i + 1]
    loss_at_start = loss_window[0]
    loss_at_end = loss_window[-1]
    avg_loss_in_window = np.mean(loss_window)

    # Check if the decrease over the window is smaller than a threshold
    # We also require the average loss to be above zero to avoid division by zero or noise issues at very low losses
    if avg_loss_in_window > 1e-9 and (loss_at_start - loss_at_end) / avg_loss_in_window < relative_loss_change_threshold:
        plateau_start_epoch_train = epochs[i - window_size]
        print(f"Training loss decrease over {window_size} epochs fell below a relative threshold of {relative_loss_change_threshold:.4f} starting around Epoch {plateau_start_epoch_train}.")
        # break # Uncomment to stop at the first training plateau detection

if plateau_start_epoch_train is None:
     print("Training loss was still decreasing significantly throughout training (no plateau of size >= window_size detected).")


# --- Plotting the learning rates (as provided by user) ---
plt.figure(figsize=(12, 6))
plt.plot(epochs, learning_rates, label='Learning Rate', linestyle='-', color='tab:blue')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.legend()
plt.title('Learning Rate over Epochs')
plt.grid(True)
plt.tight_layout()
plt.savefig('learning_rate_plot.png', dpi=150)
plt.show()
# plt.close() # Keep commented if you want the plot to display immediately

  plt.show()


Training Complete

Analyzing Validation Loss:
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 612.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 728.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 898.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 1010.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 1159.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 1244.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 1355.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 1442.
Validation loss showed no significant improvement for 50 consecutive epochs starting around Epoch 1625.
Validation loss showe

  plt.show()


In [33]:
# %% [markdown]
# ### Test the Trained Model on n_X = 5e8 Dataset

# Load the saved model
model = BB84NN().to(device)
model.load_state_dict(torch.load('bb84_nn_model.pth'))
model.eval()
print("Loaded saved model for testing.")

# Evaluate the model on the n_X = 5e8 dataset
with torch.no_grad():
    predicted_params_scaled_un_seen = model(X_test_tensor_un_seen).cpu().numpy()
    predicted_params_scaled_un_seen = np.clip(predicted_params_scaled_un_seen, 0, 1)
    predicted_params_un_seen = y_scaler.inverse_transform(predicted_params_scaled_un_seen)
    print("Predicted parameters (mu_1, mu_2, P_mu_1, P_mu_2, P_X):")
    print(predicted_params_un_seen[:5])
    predicted_params_un_seen[:, 0] = np.maximum(predicted_params_un_seen[:, 0], 1e-6)
    predicted_params_un_seen[:, 1] = np.maximum(predicted_params_un_seen[:, 1], 1e-6)
    predicted_params_un_seen[:, 2] = np.clip(predicted_params_un_seen[:, 2], 0, 1)
    predicted_params_un_seen[:, 3] = np.clip(predicted_params_un_seen[:, 3], 0, 1)
    predicted_params_un_seen[:, 4] = np.clip(predicted_params_un_seen[:, 4], 0, 1)
    predicted_key_rates_un_seen = []
    valid_indices_pred = []
    for idx, (params, L) in enumerate(zip(predicted_params_un_seen, fiber_lengths_un_seen)):
        key_rate = safe_objective(params, L, target_nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, e_mis=5e-3, P_ap=0, n_event=1)
        if np.isnan(key_rate):
            print(f"NaN detected in predicted_key_rates_un_seen at index {idx} with params {params}, fiber_length={L}")
        else:
            predicted_key_rates_un_seen.append(key_rate)
            valid_indices_pred.append(idx)
    predicted_key_rates_un_seen = np.array(predicted_key_rates_un_seen)

    fiber_lengths_un_seen = fiber_lengths_un_seen[valid_indices_pred]
    optimized_key_rates_un_seen = optimized_key_rates_un_seen[valid_indices_pred]
    optimized_params_array_un_seen = optimized_params_array_un_seen[valid_indices_pred]
    predicted_params_un_seen = predicted_params_un_seen[valid_indices_pred]
    print(f"After filtering NaN in predicted_key_rates_un_seen, {len(predicted_key_rates_un_seen)} data points remain for n_X = 5e8.")

# Debug: Check for NaN in optimized_key_rates_un_seen and predicted_key_rates_un_seen
print("Checking for NaN in optimized_key_rates_un_seen:")
print(f"Number of NaN values: {np.isnan(optimized_key_rates_un_seen).sum()}")
print("Checking for NaN in predicted_key_rates_un_seen:")
print(f"Number of NaN values: {np.isnan(predicted_key_rates_un_seen).sum()}")

Loaded saved model for testing.
Predicted parameters (mu_1, mu_2, P_mu_1, P_mu_2, P_X):
[[0.49174497 0.35559565 0.18052845 0.7699288  0.8551783 ]
 [0.4907131  0.35526213 0.18107088 0.76955384 0.8555565 ]
 [0.48968118 0.35492855 0.18161334 0.7691789  0.8559347 ]
 [0.4886493  0.35459507 0.18215576 0.7688039  0.85631305]
 [0.4876174  0.35426152 0.18269818 0.76842886 0.85669124]]
objective returned negative key rate -1e+250 for params [0.49174497 0.35559565 0.18052845 0.7699288  0.8551783 ], L=0.0, nx=500000000
objective returned negative key rate -1e+250 for params [0.4907131  0.35526213 0.18107088 0.76955384 0.8555565 ], L=2.0202020202020203, nx=500000000
objective returned negative key rate -1e+250 for params [0.48968118 0.35492855 0.18161334 0.7691789  0.8559347 ], L=4.040404040404041, nx=500000000
objective returned negative key rate -1e+250 for params [0.4886493  0.35459507 0.18215576 0.7688039  0.85631305], L=6.0606060606060606, nx=500000000
objective returned negative key rate -1e+

In [34]:
def plot_keyrate_and_parameters(fiber_lengths, optimized_key_rates, optimized_params_array, 
                                predicted_key_rates, predicted_params_array, epoch, filename, 
                                learning_rates=None, nx=None):
    """
    Plots key rates and parameters for a specific n_X value.
    
    Parameters:
    - fiber_lengths: Array of fiber lengths.
    - optimized_key_rates: Array of optimized key rates.
- optimized_params_array: Array of optimized parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
    - predicted_key_rates: Array of predicted key rates.
    - predicted_params_array: Array of predicted parameters [mu_1, mu_2, P_mu_1, P_mu_2, P_X].
    - epoch: Epoch number or string (e.g., 'final_test').
    - filename: Output filename for the plot.
    - learning_rates: List of learning rates over epochs (optional, not used).
    - nx: The n_X value for the plot title (optional).
    """

    # Start timing
    plot_start_time = time.time()

    # Create a figure with 2 subplots: key rate (left), all parameters (right)
    fig = plt.figure(figsize=(15, 6))  # Adjusted height for horizontal layout

    # Define the grid layout: 1 row, 2 columns
    gs = fig.add_gridspec(1, 2, width_ratios=[1, 1])

    # Key Rate Plot (Left)
    ax_keyrate = fig.add_subplot(gs[0, 0])

    # Filter out NaN values
    valid_mask = ~(np.isnan(optimized_key_rates) | np.isnan(predicted_key_rates))
    fiber_lengths = fiber_lengths[valid_mask]
    optimized_key_rates = optimized_key_rates[valid_mask]
    predicted_key_rates = predicted_key_rates[valid_mask]
    optimized_params_array = optimized_params_array[valid_mask]
    predicted_params_array = predicted_params_array[valid_mask]

    # Find cutoff where optimized key rate becomes very small
    threshold = 1e-8
    cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
    if len(cutoff_idx) > 0:
        cutoff_idx = cutoff_idx[0]
    else:
        cutoff_idx = len(fiber_lengths)

    fiber_lengths = fiber_lengths[:cutoff_idx]
    optimized_key_rates = optimized_key_rates[:cutoff_idx]
    predicted_key_rates = predicted_key_rates[:cutoff_idx]
    optimized_params_array = optimized_params_array[:cutoff_idx]
    predicted_params_array = predicted_params_array[:cutoff_idx]

    # Plot key rates with distinct linestyles and increased thickness for predicted
    
    ax_keyrate.plot(fiber_lengths, np.log10(predicted_key_rates), 'r--', label='Predicted Key Rate', linewidth=4.0, alpha=0.7)
    ax_keyrate.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linewidth=2.0, alpha=1.0)
    exponent = int(np.log10(nx))

    ax_keyrate.set_title(f"Key Rates for $n_X = 10^{{{exponent}}}$") # (Epoch {epoch})
    ax_keyrate.set_xlabel("Fiber Length (km)")
    ax_keyrate.set_ylabel("Secret Key Rate per Pulse")
    ax_keyrate.grid(True)
    ax_keyrate.legend(loc='upper right')

    # Parameter Plot (Right) - Combine probabilities and intensities
    ax_params = fig.add_subplot(gs[0, 1])

    param_labels = ['$\mu_1$', '$\mu_2$', '$P_{\mu_1}$', '$P_{\mu_2}$', '$P_X$']
    colors = ['red', 'purple', 'orange', 'green', 'blue']
    param_indices = [0, 1, 2, 3, 4]  # Indices for mu_1, mu_2, P_mu_1, P_mu_2, P_X  

    # Plot all parameters
    for param_idx, (label, color) in zip(param_indices, zip(param_labels, colors)):
        ax_params.plot(fiber_lengths, predicted_params_array[:, param_idx], label=f'Predicted {label}', color=color, linestyle='--', linewidth=4.0, alpha=0.7)
        ax_params.plot(fiber_lengths, optimized_params_array[:, param_idx], label=f'Optimized {label}', color=color, linestyle='-', linewidth=2.0,alpha=1.0)
        
     # Compute the exponent for nx (e.g., nx = 10000 -> exponent = 4)
    exponent = int(np.log10(nx))
    ax_params.set_title(f"Parameters for $n_X = 5 * 10^{{{exponent}}}$") # (Epoch {epoch})
    ax_params.set_xlabel("Fiber Length (km)")
    ax_params.set_ylabel("Parameter Value")
    ax_params.set_ylim(0.0, 1.0)
    ax_params.grid(True)

    # Move legend outside the plot to avoid overlap
    ax_params.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Adjust layout to prevent overlap
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"Comparison plot saved to {filename}")


# Plot the comparison for n_X = 5e8
plot_start_time = time.time()
plot_keyrate_and_parameters(
    fiber_lengths_un_seen, optimized_key_rates_un_seen, optimized_params_array_un_seen,
    predicted_key_rates_un_seen, predicted_params_un_seen, epoch='final_test',
    filename='keyrate_parameters_5e8.png', learning_rates=learning_rates, nx=target_nx
)

print(f"--- Final Test Plotting Time (keyrate_and_parameters): {time.time() - plot_start_time:.2f} seconds ---")

# Plot relative errors for n_X = 5e8
plot_relative_errors(
    fiber_lengths_un_seen, optimized_key_rates_un_seen, optimized_params_array_un_seen,
    predicted_key_rates_un_seen, predicted_params_un_seen, epoch='final_test',
    filename=f'relative_error_nx_{target_nx:.0e}.png', nx=target_nx
)

# Add the new plot here
plot_absolute_differences_keyrate_and_parameters(
    fiber_lengths_un_seen, optimized_key_rates_un_seen, optimized_params_array_un_seen,
    predicted_key_rates_un_seen, predicted_params_un_seen, epoch='final_test',
    filename='absolute_differences_5e8.png', nx=target_nx
)


  ax_keyrate.plot(fiber_lengths, np.log10(predicted_key_rates), 'r--', label='Predicted Key Rate', linewidth=4.0, alpha=0.7)
  plt.show()
  plt.tight_layout(rect=[0, 0, 1, 0.95])


Comparison plot saved to keyrate_parameters_5e8.png
--- Final Test Plotting Time (keyrate_and_parameters): 0.69 seconds ---
Relative error plot saved to relative_error_nx_5e+08.png
--- Plotting Time: 1.10 seconds ---


  plt.tight_layout(rect=[0, 0, 1, 0.95])


Absolute differences plot saved to absolute_differences_5e8.png
--- Plotting Time: 1.26 seconds ---


In [35]:
# Evaluate and plot for all n_X values (10^4 to 10^9)
for nx in all_evaluation_data.keys():
    eval_data = all_evaluation_data[nx]
    fiber_lengths = eval_data['fiber_lengths']
    optimized_params_array = eval_data['optimized_params_array']
    optimized_key_rates = eval_data['optimized_key_rates']
    X_test_tensor = eval_data['X_test_tensor']

    with torch.no_grad():
        predicted_params_scaled = model(X_test_tensor).cpu().numpy()
        predicted_params_scaled = np.clip(predicted_params_scaled, 0, 1)
        predicted_params_array = y_scaler.inverse_transform(predicted_params_scaled)
        predicted_params_array[:, 0] = np.maximum(predicted_params_array[:, 0], 1e-6)
        predicted_params_array[:, 1] = np.maximum(predicted_params_array[:, 1], 1e-6)
        predicted_params_array[:, 2] = np.clip(predicted_params_array[:, 2], 0, 1)
        predicted_params_array[:, 3] = np.clip(predicted_params_array[:, 3], 0, 1)
        predicted_params_array[:, 4] = np.clip(predicted_params_array[:, 4], 0, 1)
        predicted_key_rates = []
        for params, L in zip(predicted_params_array, fiber_lengths):
            key_rate = safe_objective(params, L, nx, alpha=0.2, eta_Bob=0.1, P_dc_value=6e-7, 
                                     epsilon_sec=1e-10, epsilon_cor=1e-15, f_EC=1.16, 
                                     e_mis=5e-3, P_ap=0, n_event=1)
            predicted_key_rates.append(key_rate)
        predicted_key_rates = np.array(predicted_key_rates)

    # Plot relative errors for this n_X
    plot_start_time = time.time()
    plot_relative_errors(
        fiber_lengths, optimized_key_rates, optimized_params_array,
        predicted_key_rates, predicted_params_array, epoch='final_test',
        filename=f'relative_error_nx_{nx:.0e}.png', nx=nx
    )
    print(f"--- Final Test Plotting Time (relative_error_nx_{nx:.0e}): {time.time() - plot_start_time:.2f} seconds ---")

objective returned negative key rate -1e+250 for params [0.6072972  0.14025092 0.04766058 0.6189414  0.40844586], L=0.0, nx=10000
objective returned negative key rate -1e+250 for params [0.6069195  0.14003786 0.0475364  0.61850834 0.40851188], L=0.6006006006006006, nx=10000
objective returned negative key rate -1e+250 for params [0.6067934  0.1399715  0.04749403 0.6183753  0.40855283], L=0.8008008008008008, nx=10000
objective returned negative key rate -1e+250 for params [0.60666716 0.13990508 0.04745167 0.61824214 0.4085935 ], L=1.001001001001001, nx=10000
objective returned negative key rate -1e+250 for params [0.606541   0.13983876 0.04740933 0.61810905 0.40863457], L=1.2012012012012012, nx=10000
objective returned negative key rate -1e+250 for params [0.60641474 0.13977239 0.04736705 0.61797595 0.4086754 ], L=1.4014014014014016, nx=10000
objective returned negative key rate -1e+250 for params [0.60628855 0.13970603 0.04732463 0.6178429  0.4087164 ], L=1.6016016016016017, nx=10000
o

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+04.png
--- Plotting Time: 1.13 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+04): 1.13 seconds ---
objective returned negative key rate -1e+250 for params [0.58182997 0.21564907 0.09028985 0.69892424 0.5577087 ], L=0.0, nx=100000
objective returned negative key rate -1e+250 for params [0.5817103  0.21554945 0.09028215 0.6987757  0.5576416 ], L=0.2002002002002002, nx=100000
objective returned negative key rate -1e+250 for params [0.58147097 0.21535023 0.09026685 0.6984784  0.55750734], L=0.6006006006006006, nx=100000
objective returned negative key rate -1e+250 for params [0.58135134 0.21525061 0.09025914 0.69832987 0.5574404 ], L=0.8008008008008008, nx=100000
objective returned negative key rate -1e+250 for params [0.5812317  0.21515101 0.09025147 0.6981812  0.5573732 ], L=1.001001001001001, nx=100000
objective returned negative key rate -1e+250 for params [0.581112   0.21505141 0.09024385 0.69803256 0.5573061 ], L=1.20

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+05.png
--- Plotting Time: 1.25 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+05): 1.25 seconds ---
objective returned negative key rate -1e+250 for params [0.5501555  0.2681113  0.13089064 0.74261904 0.66213644], L=0.0, nx=1000000
objective returned negative key rate -1e+250 for params [0.55004585 0.26806915 0.1308814  0.7425602  0.66214496], L=0.2002002002002002, nx=1000000
objective returned negative key rate -1e+250 for params [0.54982644 0.2679849  0.13086297 0.7424427  0.66216195], L=0.6006006006006006, nx=1000000
objective returned negative key rate -1e+250 for params [0.54971665 0.26794273 0.13085374 0.74238396 0.6621705 ], L=0.8008008008008008, nx=1000000
objective returned negative key rate -1e+250 for params [0.54949725 0.26785845 0.13083531 0.74226636 0.6621875 ], L=1.2012012012012012, nx=1000000
objective returned negative key rate -1e+250 for params [0.5493876  0.26781633 0.13082607 0.7422075  0.662196  ], 

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+06.png
--- Plotting Time: 1.13 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+06): 1.13 seconds ---
objective returned negative key rate -1e+250 for params [0.523607   0.30907556 0.14995316 0.7649051  0.7575522 ], L=0.0, nx=10000000
objective returned negative key rate -1e+250 for params [0.52351975 0.30902073 0.14995466 0.7648488  0.7575371 ], L=0.2002002002002002, nx=10000000
objective returned negative key rate -1e+250 for params [0.52343255 0.30896595 0.14995621 0.76479244 0.757522  ], L=0.4004004004004004, nx=10000000
objective returned negative key rate -1e+250 for params [0.52334535 0.30891117 0.14995776 0.76473606 0.75750697], L=0.6006006006006006, nx=10000000
objective returned negative key rate -1e+250 for params [0.5231709 0.3088015 0.1499608 0.7646233 0.7574767], L=1.001001001001001, nx=10000000
objective returned negative key rate -1e+250 for params [0.5230836  0.30874673 0.14996234 0.7645669  0.7574616 ], L

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+07.png
--- Plotting Time: 1.26 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+07): 1.26 seconds ---
objective returned negative key rate -1e+250 for params [0.5047476  0.33991244 0.16675976 0.77275324 0.8199466 ], L=0.0, nx=100000000
objective returned negative key rate -1e+250 for params [0.5046353  0.33986747 0.16677073 0.77275324 0.81997263], L=0.2002002002002002, nx=100000000
objective returned negative key rate -1e+250 for params [0.50452286 0.3398225  0.16678162 0.77275324 0.81999886], L=0.4004004004004004, nx=100000000
objective returned negative key rate -1e+250 for params [0.50441045 0.3397776  0.1667926  0.77275324 0.820025  ], L=0.6006006006006006, nx=100000000
objective returned negative key rate -1e+250 for params [0.5042981  0.3397326  0.16680352 0.77275324 0.8200512 ], L=0.8008008008008008, nx=100000000
objective returned negative key rate -1e+250 for params [0.50418574 0.3396876  0.16681443 0.77275324 0.8

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+08.png
--- Plotting Time: 1.13 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+08): 1.13 seconds ---
objective returned negative key rate -1e+250 for params [0.48881444 0.36304    0.18532406 0.77275324 0.87579834], L=0.0, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48870486 0.3630018  0.18540026 0.77275324 0.8757886 ], L=0.2002002002002002, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48859522 0.36296362 0.18547645 0.77275324 0.875779  ], L=0.4004004004004004, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48848563 0.3629254  0.18555264 0.77275324 0.87576926], L=0.6006006006006006, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48837605 0.36288726 0.18562882 0.77275324 0.8757596 ], L=0.8008008008008008, nx=1000000000
objective returned negative key rate -1e+250 for params [0.48826644 0.36284903 0.18570499 0.7727532

  plt.tight_layout(rect=[0, 0, 1, 0.95])


Relative error plot saved to relative_error_nx_1e+09.png
--- Plotting Time: 1.13 seconds ---
--- Final Test Plotting Time (relative_error_nx_1e+09): 1.13 seconds ---


In [36]:
# Compute Mean Absolute Error (MAE) for parameters and key rates
mae_params = mean_absolute_error(optimized_params_array_un_seen, predicted_params_un_seen, multioutput='raw_values')
param_labels = ['$mu_1$', '$mu_2$', '$P_{mu_1}$', '$P_{mu_2}$', '$P_X$']
for label, mae in zip(param_labels, mae_params):
    print(f"MAE for {label}: {mae:.6f}")

# Compute MAE for key rates, handling NaN values
valid_mask = ~(np.isnan(optimized_key_rates_un_seen) | np.isnan(predicted_key_rates_un_seen))
if not np.any(valid_mask):
    print("Warning: All key rates are NaN. Cannot compute MAE for key rates.")
    mae_key_rates = float('nan')
else:
    mae_key_rates = mean_absolute_error(
        optimized_key_rates_un_seen[valid_mask],
        predicted_key_rates_un_seen[valid_mask]
    )
    print(f"Computed MAE for Key Rates using {np.sum(valid_mask)} valid data points.")
print(f"MAE for Key Rates: {mae_key_rates:.6f}")

MAE for $mu_1$: 0.018956
MAE for $mu_2$: 0.014681
MAE for $P_{mu_1}$: 0.040708
MAE for $P_{mu_2}$: 0.028440
MAE for $P_X$: 0.013721
Computed MAE for Key Rates using 100 valid data points.
MAE for Key Rates: 0.001294


In [37]:
def plot_keyrate_subplots_with_relative_error(all_data, epoch, filename):
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True)
    fig.suptitle(f"Key Rates and Relative Error vs Fiber Length at Epoch {epoch+1} ($n_X = 10^s, s=4,5,...,9$)", fontsize=16)

    axes = axes.flatten()
    nx_values = sorted(all_data.keys())

    for i, nx in enumerate(nx_values):
        if i >= 6:
            break

        ax = axes[i]
        fiber_lengths = all_data[nx]['fiber_lengths']
        optimized_key_rates = all_data[nx]['optimized_key_rates']
        predicted_key_rates = all_data[nx].get('predicted_key_rates', None)

        valid_mask = ~(np.isnan(optimized_key_rates))
        if predicted_key_rates is not None:
            valid_mask &= ~np.isnan(predicted_key_rates)
        fiber_lengths = fiber_lengths[valid_mask]
        optimized_key_rates = optimized_key_rates[valid_mask]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[valid_mask]

        threshold = 1e-8
        cutoff_idx = np.where(optimized_key_rates <= threshold)[0]
        if len(cutoff_idx) > 0:
            cutoff_idx = cutoff_idx[0]
        else:
            cutoff_idx = len(fiber_lengths)

        fiber_lengths = fiber_lengths[:cutoff_idx]
        optimized_key_rates = optimized_key_rates[:cutoff_idx]
        if predicted_key_rates is not None:
            predicted_key_rates = predicted_key_rates[:cutoff_idx]

        # Plot key rates
        ax.plot(fiber_lengths, np.log10(optimized_key_rates), 'b-', label='Optimized Key Rate', linewidth=2.0)
        if predicted_key_rates is not None:
            ax.plot(fiber_lengths, np.log10(predicted_key_rates), 'r:', label='Predicted Key Rate', linewidth=4.0)

        # Plot relative error on a secondary axis
        ax2 = ax.twinx()
        denominator = np.maximum(optimized_key_rates, threshold)
        rel_error = (predicted_key_rates - optimized_key_rates) / denominator
        rel_error = np.clip(rel_error, -10, 10)
        ax2.plot(fiber_lengths, rel_error, 'g--', label='Relative Error', alpha=0.5)
        ax2.set_ylabel('Relative Error', color='g')
        ax2.tick_params(axis='y', labelcolor='g')
        ax2.set_ylim(-10, 10)

        exponent = int(np.log10(nx))
        ax.set_title(f"$n_X = 10^{{{exponent}}}$")
        ax.grid(True)

        if i >= 3:
            ax.set_xlabel("Fiber Length (km)")
        if i % 3 == 0:
            ax.set_ylabel("Secret Key Rate per Pulse")
        ax.legend(loc='upper left')
        ax2.legend(loc='upper right')

    for i in range(len(nx_values), 6):
        fig.delaxes(axes[i])

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"Key rate subplots with relative error saved to {filename}")

In [38]:
if epoch == num_epochs - 1:
    plot_keyrate_subplots_with_relative_error(plot_data, epoch, f'keyrate_subplots_with_rel_error_last_epoch.png')
    plot_parameters_subplots(plot_data, epoch, f'parameters_subplots_last_epoch.png')