In [20]:
import os
import re
import pandas as pd
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
import matplotlib as mpl
import warnings
import math

# --- Configuration ---
# Make sure this points to the correct directory containing the distance experiment results
RESULTS_DIR = "distance_qec_fiber_loss"
PLOTS_DIR = "plots_distance_qec_fiber_loss"
os.makedirs(PLOTS_DIR, exist_ok=True)

# --- Plotting Configuration (Copied from single_qubit_analysis.py) ---
def configure_plots():
    """Configure matplotlib rcParams for publication-quality plots"""
    plt.style.use('seaborn-v0_8-whitegrid') # Use an available style
    mpl.rcParams.update({
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 10, # Adjusted legend size
        'legend.frameon': True,
        'legend.framealpha': 0.8,
        'figure.figsize': (12, 7), # Slightly wider default
        'figure.dpi': 300,
        'text.usetex': False,  # Keep False for broader compatibility
        'mathtext.fontset': 'stix',
        'axes.prop_cycle': plt.cycler('color', plt.cm.viridis(np.linspace(0, 0.85, 8))),
        'axes.axisbelow': True,
        'grid.alpha': 0.6,
        'grid.linestyle': ':'
    })

configure_plots() # Apply the configuration
warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')
warnings.filterwarnings('ignore', category=RuntimeWarning) # Ignore potential runtime warnings from mean of empty slice

# --- Helper Functions ---
def format_gamma(value, precision=4):
    """Formats gamma values, truncating long decimals."""
    if isinstance(value, (int, float)):
        if abs(value - round(value)) < 1e-9: # Check if it's effectively an integer
             return str(int(value))
        else:
             # Truncate instead of round
             factor = 10.0 ** precision
             return str(math.trunc(value * factor) / factor)
             # Alternative: standard formatting
             # return f"{value:.{precision}g}"
    return str(value)

def parse_param_string_distance(param_str):
    """Parses complex parameter strings from distance experiments.
       Handles p_loss_init, p_loss_length, sec_gamma, ter_gamma etc.
    """
    params = {}
    parts = param_str.split('_')
    i = 0
    current_key = ""
    while i < len(parts):
        part = parts[i]
        # Check if the part looks like a value (numeric, possibly with decimal or E notation)
        is_value = False
        try:
            float(part)
            is_value = True
        except ValueError:
            is_value = False

        if not is_value and current_key:
            # If we have a key and the current part isn't a value, it's part of the key
            current_key += "_" + part
            i += 1
        elif not is_value and not current_key:
            # Starting a new key
            current_key = part
            i += 1
        elif is_value and current_key:
            # Found the value for the current key
            try:
                 params[current_key] = float(part)
            except ValueError:
                 params[current_key] = part # Keep as string if not float
            current_key = "" # Reset key
            i += 1
        else:
            # Edge case: value without a preceding key? Or starting with a value?
            # Treat as part of the next key or ignore if error
            # print(f"Warning: Unexpected part '{part}' in param string: {param_str}")
            i += 1 # Move on

    # Rename keys for clarity if needed (optional)
    final_params = {}
    for k, v in params.items():
        if k.startswith("sec_"):
            final_params[k] = v # Keep prefix for now
        elif k.startswith("ter_"):
            final_params[k] = v # Keep prefix for now
        else:
            final_params[k] = v

    return final_params

def format_params_dict(params_dict):
    """Creates a readable string representation of the params dict, formatting gamma."""
    if not params_dict:
        return "No Params"
    items = []
    for k, v in sorted(params_dict.items()):
        if 'gamma' in k:
            items.append(f"{k}={format_gamma(v)}")
        else:
            items.append(f"{k}={v}")
    return ", ".join(items)

# --- Data Loading Function --- 
def load_distance_qec_data(results_dir=RESULTS_DIR):
    """
    Loads experiment data from the distance_qec_network results structure.
    Extracts metrics from metadata files and calculates loss-adjusted fidelity.
    """
    print(f"Loading distance experiment data from: {results_dir}")
    all_experiments_data = []
    skipped_files = []
    loaded_count = 0
    file_parse_errors = []

    # Regex to capture: state, qec_method, distance, params_string
    # Example: +_none_d10.0_p_loss_init_0.05_p_loss_length_0.16_sec_gamma_0.01_ter_gamma_0.2.csv
    filename_pattern = re.compile(r"^([+\-01])_(.+?)_d([\d\.]+?)_(.+)\.csv$")

    # Iterate through error combination directories
    error_combo_dirs = [d for d in os.listdir(results_dir) if os.path.isdir(os.path.join(results_dir, d))]

    for error_combo in error_combo_dirs:
        dir_path = os.path.join(results_dir, error_combo)
        # print(f" Processing directory: {dir_path}") # Reduced verbosity

        for filename in os.listdir(dir_path):
            # Focus ONLY on metadata files for primary data extraction
            if not filename.endswith("_metadata.csv"):
                continue

            base_filename = filename.replace("_metadata.csv", ".csv")
            match = filename_pattern.match(base_filename)

            if match:
                initial_state = match.group(1)
                qec_method = match.group(2)
                distance = float(match.group(3))
                param_string = match.group(4)
                metadata_path = os.path.join(dir_path, filename)

                try:
                    # Parse the parameter string first
                    error_params = parse_param_string_distance(param_string)
                    if not error_params:
                        # print(f"  - Warning: Could not parse parameters from '{param_string}' in {base_filename}")
                        skipped_files.append(base_filename)
                        continue

                    # Load metadata
                    metadata_df = pd.read_csv(metadata_path)
                    if metadata_df.empty:
                        # print(f"  - Warning: Empty metadata file: {filename}")
                        skipped_files.append(base_filename)
                        continue
                    metadata = metadata_df.iloc[0].to_dict()

                    # Extract required fields from metadata
                    raw_avg_fidelity = metadata.get('avg_fidelity', 0.0)
                    raw_std_fidelity = metadata.get('std_fidelity', 0.0)
                    total_runs = metadata.get('iterations_completed', 0)
                    loss_count = metadata.get('loss_count', 0)

                    # Calculate loss-adjusted metrics
                    loss_adjusted_avg_fidelity = 0.0
                    loss_ratio = 0.0
                    successful_runs = total_runs - loss_count

                    if total_runs > 0:
                        loss_ratio = loss_count / total_runs
                        # Adjust fidelity: assume lost runs have fidelity 0
                        # Also handle case where raw_avg_fidelity is NaN (if successful_runs is 0)
                        if successful_runs > 0 and pd.notna(raw_avg_fidelity):
                             loss_adjusted_avg_fidelity = (raw_avg_fidelity * successful_runs) / total_runs
                        else:
                             loss_adjusted_avg_fidelity = 0.0 # If no successful runs or raw fidelity is NaN

                    else:
                         # Handle case with 0 iterations completed? Set fidelities to 0.
                         raw_avg_fidelity = 0.0
                         raw_std_fidelity = 0.0
                         loss_adjusted_avg_fidelity = 0.0

                    experiment_entry = {
                         "initial_state": initial_state,
                         "qec_method": qec_method,
                         "distance": distance,
                         "error_combo": error_combo, # Store the directory name as the error combo identifier
                         "error_params": error_params,
                         "param_string": format_params_dict(error_params), # Store readable param string
                         "avg_fidelity_raw": raw_avg_fidelity,
                         "std_fidelity_raw": raw_std_fidelity, # Std dev of individual trial fidelities (from metadata)
                         "avg_fidelity_adj": loss_adjusted_avg_fidelity,
                         "loss_ratio": loss_ratio,
                         "loss_count": loss_count,
                         "total_runs": total_runs,
                         "filename_base": base_filename
                         # We are not loading the full data CSV by default here
                    }

                    all_experiments_data.append(experiment_entry)
                    loaded_count += 1

                except Exception as e:
                    print(f"  - Error processing metadata file {filename}: {e}")
                    import traceback
                    traceback.print_exc() # Print traceback for debugging loading errors
                    file_parse_errors.append((filename, str(e)))
                    skipped_files.append(base_filename)
            else:
                 # Metadata filename doesn't match expected base pattern
                 skipped_files.append(filename)

    print(f"Finished loading. Loaded {loaded_count} experiment configurations from metadata.")
    if skipped_files:
        print(f"Skipped {len(skipped_files)} files/metadata (check format/content/errors).")
    if file_parse_errors:
        print(f"Encountered {len(file_parse_errors)} errors during file processing:")
        for fname, err in file_parse_errors[:10]:
             print(f"  - {fname}: {err}")
        if len(file_parse_errors) > 10:
             print("  ... (additional errors truncated)")
    return all_experiments_data

# ===========================================
# --- NEW Plotting Functions Implementation ---
# ===========================================

# --- Mappings for Plotting ---
ERROR_COMBO_MAP = {
    'fibre_loss': 'Fiber',
    'fibre_loss_plus_amplitude_damping': 'Fiber+Amp',
    'fibre_loss_plus_phase_damping': 'Fiber+Phase',
    'fibre_loss_plus_amplitude_damping_plus_phase_damping': 'Fiber+Amp+Phase'
}

METHOD_DISPLAY_MAP = {
    'none': 'No QEC',
    'three_qubit_phase_flip': '3QB Phase',
    'shor_nine': 'Shor-9' # Add other methods if present
}

# --- Plot 1: Fidelity (Raw & Adjusted) vs Error Combination ---
def plot_fidelity_by_error_combo(df, plots_subdir="1_fidelity_vs_error_combo"):
    """
    Plots avg fidelity (raw & adjusted) vs. error combination.
    Averages over distances, QEC methods, initial states, and specific parameters.
    Includes std dev of the *mean* fidelities across aggregated groups.
    """
    print("\n--- Plot 1: Fidelity vs Error Combination ---")
    plot_dir = os.path.join(PLOTS_DIR, plots_subdir)
    os.makedirs(plot_dir, exist_ok=True)

    # Aggregate fidelities, calculating mean and std dev of the means across runs
    agg_data = df.groupby('error_combo').agg(
        mean_raw=('avg_fidelity_raw', 'mean'),
        std_raw=('avg_fidelity_raw', 'std'), # Std dev of the mean fidelities from different configs
        mean_adj=('avg_fidelity_adj', 'mean'),
        std_adj=('avg_fidelity_adj', 'std')  # Std dev of the mean adjusted fidelities
    ).reset_index()

    # Sort for consistent plotting order (optional)
    agg_data = agg_data.sort_values(by='mean_adj', ascending=False)

    if agg_data.empty:
        print("  No data to plot for Fidelity vs Error Combination.")
        return

    n_combos = len(agg_data['error_combo'])
    index = np.arange(n_combos)
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(max(10, n_combos * 1.5), 7)) # Adjust width based on number of combos

    # Plot bars
    bars_raw = ax.bar(index - bar_width/2, agg_data['mean_raw'], bar_width,
                      yerr=agg_data['std_raw'], label=r'Raw Fidelity',
                      capsize=4, alpha=0.85, edgecolor='black', linewidth=0.5)
    bars_adj = ax.bar(index + bar_width/2, agg_data['mean_adj'], bar_width,
                      yerr=agg_data['std_adj'], label=r'Loss-Adjusted Fidelity',
                      capsize=4, alpha=0.85, edgecolor='black', linewidth=0.5)

    # Add labels and title with LaTeX
    ax.set_xlabel(r'Error Combination')
    ax.set_ylabel(r'Average Fidelity $F$')
    ax.set_title(r'Overall Average Fidelity by Error Combination')
    ax.set_xticks(index)
    # Use the mapped short names for labels
    ax.set_xticklabels(agg_data['error_combo'], rotation=30, ha='right')
    ax.legend(loc='upper right') # Changed from 'best'
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', linestyle=':', alpha=0.7)
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))

    # Add value labels on top of bars (optional, can be cluttered)
    # ax.bar_label(bars_raw, fmt=r'%.2f', padding=3, fontsize=8)
    # ax.bar_label(bars_adj, fmt=r'%.2f', padding=3, fontsize=8)

    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "avg_fidelity_by_error_combo.png"))
    plt.close()

# --- Plot 2: Fidelity/Loss vs Fiber Parameters --- (REMOVED)
# def plot_fidelity_loss_vs_fiber_params(df, plots_subdir="2_vs_fiber_params"):
#     ...

# --- Plot 3: Fidelity (Raw & Adjusted) vs Distance (Combined) ---
def plot_fidelity_vs_distance_combined(df, plots_subdir="2_fidelity_vs_distance_combined"):
    """Creates a single plot showing fidelity vs distance for all error combinations."""
    print("\n--- Plotting Fidelity vs. Distance (Combined) ---")
    plot_dir = os.path.join(PLOTS_DIR, plots_subdir)
    os.makedirs(plot_dir, exist_ok=True)
    
    # Filter out too-few points
    distances = sorted(df['distance'].unique())
    if len(distances) <= 1:
        print("  Not enough distance points to create plot.")
        return
    
    # Create a single plot instead of 4 subplots
    plt.figure(figsize=(12, 8))
    ax = plt.gca()
    
    # Set up color scheme by error combination
    error_combos = sorted(df['error_combo'].unique())
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(error_combos)))
    error_colors = {combo: colors[i] for i, combo in enumerate(error_combos)}
    
    # Set up marker scheme by QEC method
    qec_methods = sorted(df['qec_method'].unique())
    markers = ['o', 's', '^', 'd', '*']
    marker_map = {qec: markers[i % len(markers)] for i, qec in enumerate(qec_methods)}
    
    # Track lines for legend
    error_lines = {}
    method_markers = {}
    
    # Plot each error combination and QEC method
    for error_combo in error_combos:
        for qec_method in qec_methods:
            subset = df[(df['error_combo'] == error_combo) & 
                        (df['qec_method'] == qec_method)]
            
            if subset.empty:
                continue
                
            # Average across other parameters for each distance
            avg_by_dist = subset.groupby('distance').agg(
                avg_fid_raw=('avg_fidelity_raw', 'mean'),
                std_fid_raw=('avg_fidelity_raw', 'std'),
                avg_fid_adj=('avg_fidelity_adj', 'mean'),
                std_fid_adj=('avg_fidelity_adj', 'std')
            ).reset_index()
            
            if avg_by_dist.empty or len(avg_by_dist) < 2:
                continue
                
            # Plot raw fidelity (solid line)
            line_raw, = ax.plot(avg_by_dist['distance'], avg_by_dist['avg_fid_raw'],
                             marker=marker_map[qec_method], linestyle='-', linewidth=2,
                             color=error_colors[error_combo], markersize=8,
                             label=f"{ERROR_COMBO_MAP.get(error_combo, error_combo)} - {METHOD_DISPLAY_MAP.get(qec_method, qec_method)} (Raw)")
            
            # Plot adjusted fidelity (dashed line)
            line_adj, = ax.plot(avg_by_dist['distance'], avg_by_dist['avg_fid_adj'],
                             marker=marker_map[qec_method], linestyle='--', linewidth=2, 
                             color=error_colors[error_combo], markersize=8,
                             label=f"{ERROR_COMBO_MAP.get(error_combo, error_combo)} - {METHOD_DISPLAY_MAP.get(qec_method, qec_method)} (Adj.)")
            
            # Add error bars
            ax.fill_between(avg_by_dist['distance'], 
                          avg_by_dist['avg_fid_raw'] - avg_by_dist['std_fid_raw'],
                          avg_by_dist['avg_fid_raw'] + avg_by_dist['std_fid_raw'], 
                          color=error_colors[error_combo], alpha=0.1)
            
            ax.fill_between(avg_by_dist['distance'], 
                          avg_by_dist['avg_fid_adj'] - avg_by_dist['std_fid_adj'],
                          avg_by_dist['avg_fid_adj'] + avg_by_dist['std_fid_adj'], 
                          color=error_colors[error_combo], alpha=0.1)
            
            # Track for legend
            error_combo_name = ERROR_COMBO_MAP.get(error_combo, error_combo)
            if error_combo_name not in error_lines:
                error_lines[error_combo_name] = line_raw
            
            method_name = METHOD_DISPLAY_MAP.get(qec_method, qec_method)
            if method_name not in method_markers:
                method_markers[method_name] = marker_map[qec_method]
    
    # Set up plot appearance
    ax.set_title("Fidelity vs. Distance by Error Combination", fontsize=16)
    ax.set_xlabel("Distance (km)", fontsize=14)
    ax.set_ylabel("Average Fidelity", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.set_xlim(left=0)
    ax.grid(True, linestyle=':', alpha=0.7)
    
    # Create a better legend
    from matplotlib.lines import Line2D
    
    # Create custom legend elements
    legend_elements = []
    
    # Add error types
    for name, line in error_lines.items():
        legend_elements.append(Line2D([0], [0], color=line.get_color(), lw=2, label=name))
    
    # Add line styles
    legend_elements.append(Line2D([0], [0], color='gray', lw=2, linestyle='-', label='Raw Fidelity'))
    legend_elements.append(Line2D([0], [0], color='gray', lw=2, linestyle='--', label='Adj. Fidelity'))
    
    # Add QEC methods
    for name, marker in method_markers.items():
        legend_elements.append(Line2D([0], [0], marker=marker, color='gray', linestyle='None',
                                     markersize=8, label=name))
    
    ax.legend(handles=legend_elements, loc='best', fontsize=10, framealpha=0.9,
             bbox_to_anchor=(1.02, 1), borderaxespad=0)
    
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "fidelity_vs_distance_all.png"), dpi=300)
    plt.close()

# --- Plot 4: QEC Performance (Raw & Adjusted) vs Error Combination (Faceted by Error Combo) ---
def plot_qec_performance_faceted_by_error(df, plots_subdir="3_qec_performance_faceted_by_error"):
    """
    Plots raw and adjusted fidelity vs. QEC method, faceted by error combination.
    Includes std dev error bars representing variation across parameters/distances.
    """
    print("\n--- Plot 3: QEC Performance (Faceted by Error Combo) ---")
    plot_dir = os.path.join(PLOTS_DIR, plots_subdir)
    os.makedirs(plot_dir, exist_ok=True)

    qec_methods = sorted(df['qec_method'].unique())
    error_combos = sorted(df['error_combo'].unique())
    # Use METHOD_DISPLAY_MAP defined globally

    # Determine grid size
    n_combos = len(error_combos)
    n_cols = 2
    n_rows = (n_combos + n_cols - 1) // n_cols

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows), sharex=True, sharey=True, squeeze=False)
    axs_flat = axs.flatten()

    for i, error_combo in enumerate(error_combos):
        ax = axs_flat[i]
        combo_subset = df[df['error_combo'] == error_combo]

        # Aggregate fidelities for this error combo
        agg_data = combo_subset.groupby('qec_method').agg(
            mean_raw=('avg_fidelity_raw', 'mean'),
            std_raw=('avg_fidelity_raw', 'std'),
            mean_adj=('avg_fidelity_adj', 'mean'),
            std_adj=('avg_fidelity_adj', 'std')
        ).reindex(qec_methods) # Ensure all methods are present and ordered

        if agg_data.empty:
             ax.text(0.5, 0.5, 'No Data', ha='center', va='center', transform=ax.transAxes)
             ax.set_title(error_combo)
             continue

        n_methods_local = len(agg_data.index)
        index = np.arange(n_methods_local)
        bar_width = 0.35

        # Plot bars
        bars_raw = ax.bar(index - bar_width/2, agg_data['mean_raw'], bar_width,
                          yerr=agg_data['std_raw'], label=r'Raw Fidelity',
                          capsize=3, alpha=0.85, edgecolor='black', linewidth=0.5)
        bars_adj = ax.bar(index + bar_width/2, agg_data['mean_adj'], bar_width,
                          yerr=agg_data['std_adj'], label=r'Loss-Adjusted Fidelity',
                          capsize=3, alpha=0.85, edgecolor='black', linewidth=0.5)

        ax.set_title(error_combo) # Use mapped short name
        ax.set_xticks(index)
        # Use METHOD_DISPLAY_MAP for labels
        ax.set_xticklabels([METHOD_DISPLAY_MAP.get(m, m) for m in agg_data.index], rotation=30, ha='right')
        ax.set_ylim(0, 1.05)
        ax.grid(axis='y', linestyle=':', alpha=0.7)
        ax.yaxis.set_minor_locator(AutoMinorLocator(2))

        if i % n_cols == 0: # Only first column
             ax.set_ylabel(r'Average Fidelity $F$')
        if i == 0: # Add legend to the first plot
            ax.legend(loc='best', fontsize=9)

    # Hide unused subplots
    for j in range(i + 1, n_rows * n_cols):
        axs_flat[j].set_visible(False)

    fig.suptitle(r'QEC Method Performance Across Error Combinations', fontsize=18, y=1.02)
    plt.tight_layout(rect=[0, 0.03, 1, 1]) # Adjust rect
    plt.savefig(os.path.join(plot_dir, "qec_performance_faceted_by_error.png"))
    plt.close()

# --- Plot 5: Loss Ratio vs Distance (Combined) ---
def plot_loss_vs_distance_combined(df, plots_subdir="4_loss_vs_distance_combined"):
    """Creates a single plot showing loss ratio vs distance for all error combinations."""
    print("\n--- Plotting Loss Ratio vs. Distance (Combined) ---")
    plot_dir = os.path.join(PLOTS_DIR, plots_subdir)
    os.makedirs(plot_dir, exist_ok=True)
    
    # Filter out too-few points
    distances = sorted(df['distance'].unique())
    if len(distances) <= 1:
        print("  Not enough distance points to create plot.")
        return
    
    # Create a single plot
    plt.figure(figsize=(12, 8))
    ax = plt.gca()
    
    # Set up color scheme by error combination
    error_combos = sorted(df['error_combo'].unique())
    colors = plt.cm.viridis(np.linspace(0, 0.85, len(error_combos)))
    error_colors = {combo: colors[i] for i, combo in enumerate(error_combos)}
    
    # Set up marker scheme by QEC method
    qec_methods = sorted(df['qec_method'].unique())
    markers = ['o', 's', '^', 'd', '*']
    marker_map = {qec: markers[i % len(markers)] for i, qec in enumerate(qec_methods)}
    
    # Track lines for legend
    error_lines = {}
    method_markers = {}
    
    # Plot each error combination and QEC method
    for error_combo in error_combos:
        for qec_method in qec_methods:
            subset = df[(df['error_combo'] == error_combo) & 
                        (df['qec_method'] == qec_method)]
            
            if subset.empty:
                continue
                
            # Average across other parameters for each distance
            avg_by_dist = subset.groupby('distance').agg(
                avg_loss=('loss_ratio', 'mean'),
                std_loss=('loss_ratio', 'std')
            ).reset_index()
            
            if avg_by_dist.empty or len(avg_by_dist) < 2:
                continue
                
            # Plot loss ratio
            line, = ax.plot(avg_by_dist['distance'], avg_by_dist['avg_loss'],
                          marker=marker_map[qec_method], linestyle='-', linewidth=2,
                          color=error_colors[error_combo], markersize=8,
                          label=f"{ERROR_COMBO_MAP.get(error_combo, error_combo)} - {METHOD_DISPLAY_MAP.get(qec_method, qec_method)}")
            
            # Add error bars
            ax.fill_between(avg_by_dist['distance'], 
                          avg_by_dist['avg_loss'] - avg_by_dist['std_loss'],
                          avg_by_dist['avg_loss'] + avg_by_dist['std_loss'], 
                          color=error_colors[error_combo], alpha=0.1)
            
            # Track for legend
            error_combo_name = ERROR_COMBO_MAP.get(error_combo, error_combo)
            if error_combo_name not in error_lines:
                error_lines[error_combo_name] = line
            
            method_name = METHOD_DISPLAY_MAP.get(qec_method, qec_method)
            if method_name not in method_markers:
                method_markers[method_name] = marker_map[qec_method]
    
    # Set up plot appearance
    ax.set_title("Loss Ratio vs. Distance by Error Combination", fontsize=16)
    ax.set_xlabel("Distance (km)", fontsize=14)
    ax.set_ylabel("Average Loss Ratio", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.set_xlim(left=0)
    ax.grid(True, linestyle=':', alpha=0.7)
    
    # Create a better legend
    from matplotlib.lines import Line2D
    
    # Create custom legend elements
    legend_elements = []
    
    # Add error types
    for name, line in error_lines.items():
        legend_elements.append(Line2D([0], [0], color=line.get_color(), lw=2, label=name))
    
    # Add QEC methods
    for name, marker in method_markers.items():
        legend_elements.append(Line2D([0], [0], marker=marker, color='gray', linestyle='None',
                                     markersize=8, label=name))
    
    ax.legend(handles=legend_elements, loc='best', fontsize=10, framealpha=0.9,
             bbox_to_anchor=(1.02, 1), borderaxespad=0)
    
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "loss_vs_distance_all.png"), dpi=300)
    plt.close()

# --- Add new function for fiber parameter analysis ---
def plot_fiber_params_analysis(df, plots_subdir="5_fiber_params_analysis"):
    """
    Creates plots analyzing the effect of fiber loss parameters:
    1. Fidelity vs p_loss_init
    2. Fidelity vs p_loss_length
    """
    print("\n--- Plotting Fidelity vs. Fiber Loss Parameters ---")
    plot_dir = os.path.join(PLOTS_DIR, plots_subdir)
    os.makedirs(plot_dir, exist_ok=True)
    
    # Only include fiber loss experiments
    fiber_df = df[df['error_combo'] == 'fibre_loss'].copy()
    if fiber_df.empty:
        print("  No fiber loss data found for parameter analysis.")
        return
        
    # Extract parameters
    fiber_df['p_loss_init'] = fiber_df['error_params'].apply(
        lambda p: p.get('p_loss_init', None) if isinstance(p, dict) else None)
    fiber_df['p_loss_length'] = fiber_df['error_params'].apply(
        lambda p: p.get('p_loss_length', None) if isinstance(p, dict) else None)
    
    # Drop rows with missing parameters
    fiber_df = fiber_df.dropna(subset=['p_loss_init', 'p_loss_length'])
    
    if fiber_df.empty:
        print("  Unable to extract fiber parameters from data.")
        return
    
    # --- 1. Plot Fidelity vs p_loss_init ---
    print("  Plotting fidelity vs p_loss_init...")
    
    # Get unique values for faceting
    p_lens = sorted(fiber_df['p_loss_length'].unique())
    distances = sorted(fiber_df['distance'].unique())
    
    if len(p_lens) > 1 and len(distances) > 0:
        plt.figure(figsize=(12, 8))
        ax = plt.gca()
        
        # Set up color and marker schemes
        colors = plt.cm.viridis(np.linspace(0, 0.85, len(distances)))
        dist_colors = {d: colors[i] for i, d in enumerate(distances)}
        
        qec_methods = sorted(fiber_df['qec_method'].unique())
        markers = ['o', 's', '^', 'd', '*']
        qec_markers = {qec: markers[i % len(markers)] for i, qec in enumerate(qec_methods)}
        
        linestyles = ['-', '--', '-.', ':']
        p_len_lines = {p_len: linestyles[i % len(linestyles)] for i, p_len in enumerate(p_lens)}
        
        # Track for legend
        dist_handles = {}
        method_handles = {}
        p_len_handles = {}
        
        # For each p_loss_length value
        for p_len in p_lens:
            # For each distance
            for dist in distances:
                # For each QEC method
                for qec_method in qec_methods:
                    subset = fiber_df[(fiber_df['p_loss_length'] == p_len) & 
                                      (fiber_df['distance'] == dist) & 
                                      (fiber_df['qec_method'] == qec_method)]
                    
                    if len(subset) < 2:
                        continue
                    
                    # Group by p_loss_init
                    grouped = subset.groupby('p_loss_init').agg(
                        avg_fid_raw=('avg_fidelity_raw', 'mean'),
                        std_fid_raw=('avg_fidelity_raw', 'std'),
                        avg_fid_adj=('avg_fidelity_adj', 'mean'),
                        std_fid_adj=('avg_fidelity_adj', 'std')
                    ).reset_index()
                    
                    # Sort by p_loss_init
                    grouped = grouped.sort_values('p_loss_init')
                    
                    if len(grouped) < 2:
                        continue
                    
                    # Plot raw fidelity
                    line_raw, = ax.plot(grouped['p_loss_init'], grouped['avg_fid_raw'],
                                      marker=qec_markers[qec_method], 
                                      linestyle=p_len_lines[p_len],
                                      color=dist_colors[dist], 
                                      markersize=8, linewidth=2,
                                      label=f"Dist={dist}, p_len={p_len}, {METHOD_DISPLAY_MAP.get(qec_method, qec_method)} (Raw)")
                    
                    # Plot adjusted fidelity (slightly transparent)
                    line_adj, = ax.plot(grouped['p_loss_init'], grouped['avg_fid_adj'],
                                      marker=qec_markers[qec_method], 
                                      linestyle=p_len_lines[p_len],
                                      color=dist_colors[dist], 
                                      markersize=8, linewidth=2, alpha=0.5,
                                      label=f"Dist={dist}, p_len={p_len}, {METHOD_DISPLAY_MAP.get(qec_method, qec_method)} (Adj.)")
                    
                    # Track for legend
                    dist_handles[dist] = dist_colors[dist]
                    method_handles[METHOD_DISPLAY_MAP.get(qec_method, qec_method)] = qec_markers[qec_method]
                    p_len_handles[p_len] = p_len_lines[p_len]
        
        # Set up plot appearance
        ax.set_title(r"Fidelity vs. Initial Loss Probability ($p_{loss\_init}$)", fontsize=16)
        ax.set_xlabel(r"Initial Loss Probability ($p_{loss\_init}$)", fontsize=14)
        ax.set_ylabel("Average Fidelity", fontsize=14)
        ax.set_ylim(0, 1.05)
        ax.grid(True, linestyle=':', alpha=0.7)
        
        # Create multi-part legend
        from matplotlib.lines import Line2D
        
        legend_elements = []
        
        # Distance legend elements
        for dist, color in dist_handles.items():
            legend_elements.append(Line2D([0], [0], color=color, lw=2, 
                                        label=f"Distance: {dist} km"))
        
        # Line style legend elements (p_loss_length)
        for p_len, ls in p_len_handles.items():
            legend_elements.append(Line2D([0], [0], color='gray', linestyle=ls, lw=2,
                                        label=f"Loss/km: {p_len} dB/km"))
        
        # QEC method legend elements
        for method, marker in method_handles.items():
            legend_elements.append(Line2D([0], [0], color='gray', marker=marker, linestyle='None',
                                        markersize=8, label=f"QEC: {method}"))
        
        # Fidelity type legend elements
        legend_elements.append(Line2D([0], [0], color='gray', lw=2, alpha=1.0,
                                    label='Raw Fidelity'))
        legend_elements.append(Line2D([0], [0], color='gray', lw=2, alpha=0.5,
                                    label='Adjusted Fidelity'))
        
        ax.legend(handles=legend_elements, loc='upper right', fontsize=9, 
                 framealpha=0.9, ncol=2)
        
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, "fidelity_vs_p_loss_init.png"), dpi=300)
        plt.close()
    
    # --- 2. Plot Fidelity vs p_loss_length ---
    print("  Plotting fidelity vs p_loss_length...")
    
    # Get unique values for faceting
    p_inits = sorted(fiber_df['p_loss_init'].unique())
    
    if len(p_inits) > 1 and len(distances) > 0:
        plt.figure(figsize=(12, 8))
        ax = plt.gca()
        
        # Set up color and marker schemes
        colors = plt.cm.viridis(np.linspace(0, 0.85, len(distances)))
        dist_colors = {d: colors[i] for i, d in enumerate(distances)}
        
        qec_methods = sorted(fiber_df['qec_method'].unique())
        markers = ['o', 's', '^', 'd', '*']
        qec_markers = {qec: markers[i % len(markers)] for i, qec in enumerate(qec_methods)}
        
        linestyles = ['-', '--', '-.', ':']
        p_init_lines = {p_init: linestyles[i % len(linestyles)] for i, p_init in enumerate(p_inits)}
        
        # Track for legend
        dist_handles = {}
        method_handles = {}
        p_init_handles = {}
        
        # For each p_loss_init value
        for p_init in p_inits:
            # For each distance
            for dist in distances:
                # For each QEC method
                for qec_method in qec_methods:
                    subset = fiber_df[(fiber_df['p_loss_init'] == p_init) & 
                                      (fiber_df['distance'] == dist) & 
                                      (fiber_df['qec_method'] == qec_method)]
                    
                    if len(subset) < 2:
                        continue
                    
                    # Group by p_loss_length
                    grouped = subset.groupby('p_loss_length').agg(
                        avg_fid_raw=('avg_fidelity_raw', 'mean'),
                        std_fid_raw=('avg_fidelity_raw', 'std'),
                        avg_fid_adj=('avg_fidelity_adj', 'mean'),
                        std_fid_adj=('avg_fidelity_adj', 'std')
                    ).reset_index()
                    
                    # Sort by p_loss_length
                    grouped = grouped.sort_values('p_loss_length')
                    
                    if len(grouped) < 2:
                        continue
                    
                    # Plot raw fidelity
                    line_raw, = ax.plot(grouped['p_loss_length'], grouped['avg_fid_raw'],
                                      marker=qec_markers[qec_method], 
                                      linestyle=p_init_lines[p_init],
                                      color=dist_colors[dist], 
                                      markersize=8, linewidth=2,
                                      label=f"Dist={dist}, p_init={p_init}, {METHOD_DISPLAY_MAP.get(qec_method, qec_method)} (Raw)")
                    
                    # Plot adjusted fidelity (slightly transparent)
                    line_adj, = ax.plot(grouped['p_loss_length'], grouped['avg_fid_adj'],
                                      marker=qec_markers[qec_method], 
                                      linestyle=p_init_lines[p_init],
                                      color=dist_colors[dist], 
                                      markersize=8, linewidth=2, alpha=0.5,
                                      label=f"Dist={dist}, p_init={p_init}, {METHOD_DISPLAY_MAP.get(qec_method, qec_method)} (Adj.)")
                    
                    # Track for legend
                    dist_handles[dist] = dist_colors[dist]
                    method_handles[METHOD_DISPLAY_MAP.get(qec_method, qec_method)] = qec_markers[qec_method]
                    p_init_handles[p_init] = p_init_lines[p_init]
        
        # Set up plot appearance
        ax.set_title(r"Fidelity vs. Loss per Length ($p_{loss\_length}$)", fontsize=16)
        ax.set_xlabel(r"Loss per Length (dB/km)", fontsize=14)
        ax.set_ylabel("Average Fidelity", fontsize=14)
        ax.set_ylim(0, 1.05)
        ax.grid(True, linestyle=':', alpha=0.7)
        
        # Create multi-part legend
        from matplotlib.lines import Line2D
        
        legend_elements = []
        
        # Distance legend elements
        for dist, color in dist_handles.items():
            legend_elements.append(Line2D([0], [0], color=color, lw=2, 
                                        label=f"Distance: {dist} km"))
        
        # Line style legend elements (p_loss_init)
        for p_init, ls in p_init_handles.items():
            legend_elements.append(Line2D([0], [0], color='gray', linestyle=ls, lw=2,
                                        label=f"Init Loss: {p_init}"))
        
        # QEC method legend elements
        for method, marker in method_handles.items():
            legend_elements.append(Line2D([0], [0], color='gray', marker=marker, linestyle='None',
                                        markersize=8, label=f"QEC: {method}"))
        
        # Fidelity type legend elements
        legend_elements.append(Line2D([0], [0], color='gray', lw=2, alpha=1.0,
                                    label='Raw Fidelity'))
        legend_elements.append(Line2D([0], [0], color='gray', lw=2, alpha=0.5,
                                    label='Adjusted Fidelity'))
        
        ax.legend(handles=legend_elements, loc='upper right', fontsize=9, 
                 framealpha=0.9, ncol=2)
        
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, "fidelity_vs_p_loss_length.png"), dpi=300)
        plt.close()



In [21]:
# --- Main Execution ---
if __name__ == "__main__":
    configure_plots()
    
    # Load the data (no need for full data for these plots)
    df_distance_metrics = load_distance_qec_data()
    
    if not df_distance_metrics:
        print("No distance data loaded, cannot perform analysis.")
    else:
        print(f"Loaded {len(df_distance_metrics)} experiment configurations.")
        if len(df_distance_metrics) > 0:
            print("\nExample loaded distance data point:")
            print(df_distance_metrics[0])
            
            # Call all the analysis functions
            plot_fidelity_by_error_combo(df_distance_metrics)
            plot_fidelity_vs_distance_combined(df_distance_metrics)
            plot_loss_vs_distance_combined(df_distance_metrics)
            plot_qec_performance_faceted_by_error(df_distance_metrics)
            plot_fiber_params_analysis(df_distance_metrics)
            
            print(f"\nAnalysis complete. Plots saved to: {PLOTS_DIR}") 

Loading distance experiment data from: distance_qec_fiber_loss
Finished loading. Loaded 149 experiment configurations from metadata.
Loaded 149 experiment configurations.

Example loaded distance data point:
{'initial_state': '+', 'qec_method': 'none', 'distance': 10.0, 'error_combo': 'fibre_loss_plus_phase_damping', 'error_params': {'p_loss_init': 0.05, 'p_loss_length': 0.16, 'sec_gamma': 0.01}, 'param_string': 'p_loss_init=0.05, p_loss_length=0.16, sec_gamma=0.01', 'avg_fidelity_raw': 0.997005988023952, 'std_fidelity_raw': 0.0547175655164582, 'avg_fidelity_adj': 0.5014880418803112, 'loss_ratio': 0.49700598802395207, 'loss_count': 166, 'total_runs': 334, 'filename_base': '+_none_d10.0_p_loss_init_0.05_p_loss_length_0.16_sec_gamma_0.01.csv'}

--- Plot 1: Fidelity vs Error Combination ---


AttributeError: 'list' object has no attribute 'groupby'