<a href="https://colab.research.google.com/github/EnochYounceSAIC/FiberOptics/blob/main/weeksix/gui.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import tkinter as tk
from tkinter import ttk
import numpy as np # Ensure numpy is imported as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk # Import NavigationToolbar2Tk
import scipy.special # Ensure scipy.special is imported
from matplotlib.figure import Figure


# Define the simulation parameters with default values
default_params = {
    'length': 3000,
    'span': 100,
    'EDFA_total_power': 15, # EDFA target total output power in dBm
    'EDFA_noise': 4.7, # EDFA Noise Figure in dB
    'OSNR_initial': 33.5, # Initial OSNR at transmitter output in dB (in reference bandwidth)
    # Updated span_layout format: List of tuples (fiber_type, length_km) for one span
    'span_layout': [('smf28', 33), ('leaf', 67)], # Example: 33km SMF28, 67km LEAF per span
    'channels': 70,
    'lambda_light': 1540e-9, # Starting wavelength in meters
    'symbol_rate': 10e9, # Symbol rate in Hz
    'discrete_loss_per_event_dB': 0.2, # Discrete loss per span section (e.g., connector loss)
    'snr_bandwidth_nm': 0.4e-9, # Bandwidth for SNR calculation in meters (e.g., receiver filter bandwidth)
    'reference_bandwidth_nm': 0.1e-9, # Reference bandwidth for OSNR definition in meters (0.1 nm)
    'cable_types': {
        'smf28': {'attenuation_per_km': 0.18, 'dispersion': 18, 'dispersion_slope': 0.092, 'attenuation_slope_quadratic': 2.5e-4}, # dispersion in ps/nm/km, attenuation_slope_quadratic in dB/km/nm^2
        'leaf': {'attenuation_per_km': 0.22, 'dispersion': -4, 'dispersion_slope': -0.12, 'attenuation_slope_quadratic': 2.5e-4}, # dispersion in ps/nm/km, attenuation_slope_quadratic in dB/km/nm^2
        'fiber_type_3': {'attenuation_per_km': 0.20, 'dispersion': 8, 'dispersion_slope': 0.06, 'attenuation_slope_quadratic': 2.0e-4}, # New type
        'fiber_type_4': {'attenuation_per_km': 0.25, 'dispersion': -2, 'dispersion_slope': -0.10, 'attenuation_slope_quadratic': 3.0e-4}, # New type
        'fiber_type_5': {'attenuation_per_km': 0.17, 'dispersion': 20, 'dispersion_slope': 0.08, 'attenuation_slope_quadratic': 2.2e-4}  # New type
    }
}

class OpticalFiberSimulatorGUI:
    def __init__(self, root):
        self.root = root
        root.title("Optical Fiber Simulation")

        self.notebook = ttk.Notebook(root)
        self.notebook.pack(pady=10, padx=10, expand=True, fill="both")

        # Create tabs
        self.input_tab = ttk.Frame(self.notebook)
        self.dispersion_tab = ttk.Frame(self.notebook)
        self.power_tab = ttk.Frame(self.notebook)
        self.osnr_tab = ttk.Frame(self.notebook)
        self.ber_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.input_tab, text='Parameters')
        self.notebook.add(self.dispersion_tab, text='Accumulated Dispersion')
        self.notebook.add(self.power_tab, text='Signal Power')
        self.notebook.add(self.osnr_tab, text='OSNR')
        self.ber_tab.pack(fill='both', expand=True) # Ensure BER tab expands
        self.notebook.add(self.ber_tab, text='BER')

        # Store plot canvases and figures - Initialize figures here
        self.dispersion_fig = Figure(figsize=(10, 5))
        self.power_fig = Figure(figsize=(10, 5))
        self.osnr_distance_fig = Figure(figsize=(10, 5))
        self.osnr_wavelength_fig = Figure(figsize=(10, 5))
        self.ber_fig = Figure(figsize=(10, 5))

        self.dispersion_canvas = None
        self.power_canvas = None
        self.osnr_distance_canvas = None
        self.osnr_wavelength_canvas = None
        self.ber_canvas = None

        # Add a status label
        self.status_label = ttk.Label(root, text="", foreground="red")
        self.status_label.pack(pady=5)


        self.create_input_widgets()
        self.create_plot_widgets() # create_plot_widgets now creates canvases and packs widgets

        # Store simulation results
        self.simulation_results = {}

        # Run initial simulation and display plots
        self.run_simulation()

    def create_input_widgets(self):
        # Create labels and entry fields for parameters
        self.param_entries = {}
        row = 0

        for param, default_value in default_params.items():
            if param in ['span_layout', 'cable_types']: # Handle these separately or provide a simplified input
                 continue

            label = ttk.Label(self.input_tab, text=f"{param}:")
            label.grid(row=row, column=0, padx=5, pady=5, sticky="w")

            entry = ttk.Entry(self.input_tab)
            entry.insert(0, str(default_value))
            entry.grid(row=row, column=1, padx=5, pady=5, sticky="ew")
            self.param_entries[param] = entry
            row += 1

        # --- Add Span Layout Input ---
        # This is a simplified input for up to 3 segments per span
        span_layout_label = ttk.Label(self.input_tab, text="Span Layout (Fiber Type, Length km):")
        span_layout_label.grid(row=row, column=0, padx=5, pady=5, sticky="w")
        row += 1

        self.span_segment_inputs = []
        fiber_types = list(default_params['cable_types'].keys())

        for i in range(3): # Allow up to 3 segments
            segment_frame = ttk.Frame(self.input_tab)
            segment_frame.grid(row=row, column=0, columnspan=2, padx=5, pady=2, sticky="ew")

            # Dropdown for Fiber Type
            fiber_type_var = tk.StringVar(value=fiber_types[0] if fiber_types else "")
            fiber_type_dropdown = ttk.Combobox(segment_frame, textvariable=fiber_type_var, values=fiber_types, state="readonly", width=15)
            fiber_type_dropdown.pack(side=tk.LEFT, padx=2)
            if fiber_types:
                 # Pre-fill with default if available and matches a valid fiber type
                 if i < len(default_params['span_layout']) and default_params['span_layout'][i][0] in fiber_types:
                      fiber_type_dropdown.set(default_params['span_layout'][i][0])
                 else:
                      fiber_type_dropdown.current(0)


            # Entry for Segment Length
            length_entry = ttk.Entry(segment_frame, width=10)
            if i < len(default_params['span_layout']): # Pre-fill with default length if available
                length_entry.insert(0, str(default_params['span_layout'][i][1]))
            length_entry.pack(side=tk.LEFT, padx=2)

            ttk.Label(segment_frame, text="km").pack(side=tk.LEFT)

            self.span_segment_inputs.append((fiber_type_var, length_entry))
            row += 1


        # Add a button to run the simulation
        run_button = ttk.Button(self.input_tab, text="Run Simulation", command=self.run_simulation)
        run_button.grid(row=row, column=0, columnspan=2, pady=10)


    def create_plot_widgets(self):
        # Create and pack canvases and toolbars for each tab

        # Dispersion Tab
        self.dispersion_canvas = FigureCanvasTkAgg(self.dispersion_fig, master=self.dispersion_tab)
        self.dispersion_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        toolbar_dispersion = NavigationToolbar2Tk(self.dispersion_canvas, self.dispersion_tab)
        toolbar_dispersion.update()



        # Power Tab
        self.power_canvas = FigureCanvasTkAgg(self.power_fig, master=self.power_tab)
        self.power_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        toolbar_power = NavigationToolbar2Tk(self.power_canvas, self.power_tab)
        toolbar_power.update()



        # OSNR Tab
        # OSNR Distance Plot
        self.osnr_distance_canvas = FigureCanvasTkAgg(self.osnr_distance_fig, master=self.osnr_tab)
        self.osnr_distance_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        toolbar_osnr_distance = NavigationToolbar2Tk(self.osnr_distance_canvas, self.osnr_tab)
        toolbar_osnr_distance.update()


        # OSNR Wavelength Plot
        self.osnr_wavelength_canvas = FigureCanvasTkAgg(self.osnr_wavelength_fig, master=self.osnr_tab)
        self.osnr_wavelength_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        toolbar_osnr_wavelength = NavigationToolbar2Tk(self.osnr_wavelength_canvas, self.osnr_tab)
        toolbar_osnr_wavelength.update()


        # BER Tab
        self.ber_canvas = FigureCanvasTkAgg(self.ber_fig, master=self.ber_tab)
        self.ber_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        toolbar_ber = NavigationToolbar2Tk(self.ber_canvas, self.ber_tab)
        toolbar_ber.update()



    def run_simulation(self):
        self.status_label.config(text="") # Clear previous status messages

        # Get parameters from input fields
        try:
            length = int(self.param_entries['length'].get())
            span = int(self.param_entries['span'].get())
            edfa_total_power = float(self.param_entries['EDFA_total_power'].get())
            edfa_noise = float(self.param_entries['EDFA_noise'].get())
            osnr_initial = float(self.param_entries['OSNR_initial'].get())
            channels = int(self.param_entries['channels'].get())
            lambda_light = float(self.param_entries['lambda_light'].get()) # in meters
            symbol_rate = float(self.param_entries['symbol_rate'].get()) # in Hz
            discrete_loss_per_event_dB = float(self.param_entries['discrete_loss_per_event_dB'].get())
            snr_bandwidth_nm = float(self.param_entries['snr_bandwidth_nm'].get()) # in meters
            reference_bandwidth_nm = float(default_params['reference_bandwidth_nm']) # in meters

            # Get span layout from input fields
            span_layout = []
            total_segment_length = 0
            for fiber_type_var, length_entry in self.span_segment_inputs:
                 try:
                     segment_length = int(length_entry.get())
                     fiber_type = fiber_type_var.get()
                     if segment_length > 0:
                         span_layout.append((fiber_type, segment_length))
                         total_segment_length += segment_length
                 except ValueError:
                     pass # Ignore empty or invalid entries

            if not span_layout:
                self.status_label.config(text="Error: Span layout is empty or invalid.", foreground="red")
                print("Error: Span layout is empty or invalid.")
                # Use default and proceed, or return? Let's return for now to force valid input.
                # self.status_label.config(text="Span layout empty/invalid. Using default.", foreground="orange")
                # span_layout = default_params['span_layout']
                # total_segment_length = sum(seg_length for _, seg_length in span_layout)
                return # Stop simulation if input is invalid


            # Validate that total segment length matches the span parameter
            if total_segment_length != span:
                 self.status_label.config(text=f"Error: Total span segment length ({total_segment_length} km) does not match 'span' parameter ({span} km).", foreground="red")
                 print(f"Error: Total span segment length ({total_segment_length} km) does not match 'span' parameter ({span} km).")
                 return # Stop simulation if span length mismatch


            cable_types = default_params['cable_types']

            # Validate fiber types in span_layout exist in cable_types
            for fiber_type, _ in span_layout:
                 if fiber_type not in cable_types:
                      self.status_label.config(text=f"Error: Fiber type '{fiber_type}' in span layout not found in defined cable types.", foreground="red")
                      print(f"Error: Fiber type '{fiber_type}' in span layout not found in defined cable types.")
                      return # Stop if an unknown fiber type is specified


            # Ensure length is at least 1 for simulation loop
            if length < 1:
                 self.status_label.config(text="Error: Link length must be at least 1 km.", foreground="red")
                 print("Error: Link length must be at least 1 km.")
                 return


        except ValueError as e:
            self.status_label.config(text=f"Error reading parameters: {e}", foreground="red")
            print(f"Error reading parameters: {e}")
            # Ensure simulation_results has necessary keys even on error for update_plots
            self.simulation_results = {
                 'distance': np.array([0]), # Default to a single point on error
                 'channeled_lambdas_nm': np.array([]),
                 'accumulated_dispersion': np.array([[]]),
                 'signal_power_dBm': np.array([[]]),
                 'osnr_dB': np.array([[]]),
                 'snr_dB': np.array([[]]),
                 'ber_results': {},
                 'channels_to_plot': [] # Initialize as empty list
            }
            self.update_plots(self.simulation_results)
            return # Stop simulation if parameters is invalid


        # --- Simulation Logic ---

        channeled_lambdas = np.zeros(channels)
        # Calculate channeled lambdas in meters
        for i in range(channels):
            channeled_lambdas[i] = lambda_light + i * 0.4e-9 # Assuming 400 GHz channel spacing (0.4 nm = 0.4e-9 m)


        channeled_lambdas_nm = channeled_lambdas * 1e9 # Store in nm for plotting labels


        # Prepare a list of fiber types for each kilometer across the entire link
        link_fiber_types = []
        span_length_from_layout = sum(seg_length for _, seg_length in span_layout) # Sum of segment lengths in the span layout

        # This check is now redundant due to the span == total_segment_length check above, but harmless.
        if span_length_from_layout <= 0:
             self.status_label.config(text="Error: Calculated total span length is 0 or negative. Cannot simulate.", foreground="red")
             print("Error: Calculated total span length is 0 or negative. Cannot simulate.")
             return

        # Create the fiber type list for the entire link based on repeating the span layout
        # The link consists of 'length' km. EDFAs are at 'span', '2*span', ... km.
        # Each section between EDFAs (of length 'span') uses the 'span_layout'.
        # Since we validated that sum(span_layout lengths) == span, we can just repeat the span_layout.
        link_fiber_types = []
        for km in range(length):
             km_within_span = km % span # Kilometer within the current span (0 to span-1)
             current_pos_in_span_layout = 0 # Reset for each span
             current_fiber_type = None     # Reset for each span
             for fiber_type, seg_length in span_layout:
                  if km_within_span < current_pos_in_span_layout + seg_length:
                       current_fiber_type = fiber_type
                       break
                  current_pos_in_span_layout += seg_length # Increment current_pos_in_span_layout by the segment length


             link_fiber_types.append(current_fiber_type)


        # Add a check if link_fiber_types construction resulted in None (shouldn't happen if span_layout is valid)
        if None in link_fiber_types:
             self.status_label.config(text="Internal Error: Could not determine fiber type for all kilometers.", foreground="red")
             print("Internal Error: Could not determine fiber type for all kilometers.")
             return

        # Ensure link_fiber_types has the correct length 'length'
        if len(link_fiber_types) != length:
             self.status_label.config(text=f"Internal Error: Generated link fiber types length ({len(link_fiber_types)} km) does not match link length ({length} km).", foreground="red")
             print(f"Internal Error: Generated link fiber types length ({len(link_fiber_types)} km) does not match link length ({length} km).")
             return


        # 1. Accumulated Dispersion Calculation
        # Calculate accumulated dispersion for the entire link using the dedicated function
        calculated_accumulated_dispersion = self.chrom_dispersion_total_with_slope(
            length, span_layout, cable_types, channeled_lambdas
        )

        # 2. Attenuation and Power Calculation
        multi_channel_signal_power_dBm = np.zeros((channels, length + 1))
        # Initialize multi_channel_noise_power_mW with the correct shape
        multi_channel_noise_power_mW = np.zeros((channels, length + 1))

        # Calculate initial channel powers using the pre-emphasis function
        # Assuming target_total_initial_power_dBm is the same as EDFA_total_power for initial state
        target_total_initial_power_dBm = edfa_total_power

        # Calculate the attenuation profile over a single span to use for EDFA gain weights
        # Need to recalculate span attenuation based on the specific span layout
        span_attenuation_dB_per_channel = np.zeros(channels)
        for fiber_type, seg_length in span_layout:
             current_cable_type_params = cable_types[fiber_type]
             for j in range(channels):
                  # calculate_wavelength_dependent_attenuation now takes wavelength in meters
                  attenuation_per_km_at_lambda = self.calculate_wavelength_dependent_attenuation(
                     current_cable_type_params, channeled_lambdas[j] # channeled_lambdas in meters
                 )
                  span_attenuation_dB_per_channel[j] += attenuation_per_km_at_lambda * seg_length # Multiply by segment length in km


        average_span_attenuation_dB = np.mean(span_attenuation_dB_per_channel)
        # Gain weights represent the target gain shape relative to the average gain
        # This is used to distribute the total EDFA output power across channels
        # For a flat output power after the EDFA, the EDFA gain shape should compensate for span loss shape.
        # So, gain_weights should ideally be related to span_attenuation_dB_per_channel
        # Let's use the span attenuation itself as the target gain shape in dB for simplicity
        edfa_gain_shape_target_dB = span_attenuation_dB_per_channel


        channel_power = self.calculate_preemphasis_powers(channels, edfa_gain_shape_target_dB, target_total_initial_power_dBm) # Use gain shape for pre-emphasis

        multi_channel_signal_power_dBm[:, 0] = channel_power
        # Calculate initial noise power based on initial signal power and initial OSNR
        # OSNR (dB) = Signal Power (dBm) - Noise Power (dBm) in reference bandwidth
        # Noise Power (dBm/0.1nm) = Initial Signal Power (dBm) - Initial OSNR (dB)
        initial_noise_power_ref_bw_dBm = multi_channel_signal_power_dBm[:, 0] - osnr_initial
        # Propagate noise power in mW. Initial noise power in the simulation bandwidth (snr_bandwidth_nm)
        initial_noise_power_snr_bw_dBm = initial_noise_power_ref_bw_dBm + 10 * np.log10(snr_bandwidth_nm / reference_bandwidth_nm)
        multi_channel_noise_power_mW[:, 0] = 10**(initial_noise_power_snr_bw_dBm / 10)


        multi_channel_signal_power_mW = 10**(multi_channel_signal_power_dBm / 10)


        # Discrete losses at the end of each span section (excluding the final link end)
        # For simplicity here, let's assume discrete loss only occurs *before* the EDFA at the end of each span.
        total_discrete_loss_per_span_dB = discrete_loss_per_event_dB # Assuming discrete_loss_per_event_dB is the total for one span section before EDFA


        for i in range(length): # Iterate through each kilometer
            # Get the fiber type for the current kilometer
            current_fiber_type_name = link_fiber_types[i]
            current_cable_type_params = cable_types[current_fiber_type_name]

            # Calculate attenuation per km at the current wavelength for the current fiber type
            attenuation_per_km_dB = np.zeros(channels)
            for j in range(channels):
                # calculate_wavelength_dependent_attenuation now takes wavelength in meters
                attenuation_per_km_dB[j] = self.calculate_wavelength_dependent_attenuation(
                    current_cable_type_params, channeled_lambdas[j] # channeled_lambdas in meters
                )

            # Apply attenuation for 1 km
            multi_channel_signal_power_dBm[:, i+1] = multi_channel_signal_power_dBm[:, i] - attenuation_per_km_dB
            # Noise power also attenuates. Noise power is in mW.
            multi_channel_noise_power_mW[:, i+1] = multi_channel_noise_power_mW[:, i] * 10**(-attenuation_per_km_dB / 10)


            multi_channel_signal_power_mW[:, i+1] = 10**(multi_channel_signal_power_dBm[:, i+1] / 10)


            # Check for EDFA at span boundary
            # EDFAs are located at span, 2*span, 3*span, ... km
            if (i + 1) % span == 0 and (i + 1) <= length:
                location = i + 1

                # Apply discrete loss just before the EDFA
                if total_discrete_loss_per_span_dB > 0:
                     multi_channel_signal_power_dBm[:, location] -= total_discrete_loss_per_span_dB
                     multi_channel_noise_power_mW[:, location] *= 10**(-total_discrete_loss_per_span_dB / 10) # Apply discrete loss to noise as well
                     # Update mW values after discrete loss
                     multi_channel_signal_power_mW[:, location] = 10**(multi_channel_signal_power_dBm[:, location] / 10)


                # Recalculate span attenuation for the EDFA gain calculation at this location
                # The span layout is defined at the input and assumed to be repeated.
                # So span_attenuation_dB_per_channel calculated earlier is the correct loss for *one* span.
                span_attenuation_at_location_dB_per_channel = np.zeros(channels)
                for fiber_type, seg_length in span_layout:
                    current_cable_type_params = cable_types[fiber_type]
                    for j in range(channels):
                         attenuation_per_km_at_lambda = self.calculate_wavelength_dependent_attenuation(
                            current_cable_type_params, channeled_lambdas[j]
                         )
                         span_attenuation_at_location_dB_per_channel[j] += attenuation_per_km_at_lambda * seg_length


                # Apply EDFA gain and add ASE noise
                self.const_power_EDFA_with_tilt_saturation(
                    multi_channel_signal_power_mW, # Pass mW signal
                    multi_channel_noise_power_mW, # Pass mW noise
                    edfa_total_power, # Target total output power
                    channeled_lambdas, # Wavelengths in meters
                    edfa_noise, # Noise Figure
                    location, # Just the location index
                    span_attenuation_at_location_dB_per_channel # Pass span attenuation per channel
                )

                # Apply gain equalizer after EDFA (optional, can be used for fine tuning)
                # The const_power_EDFA function now targets flat power, but this could still be used
                # for residual equalization if needed. Let's keep it for now.
                # Removing the gain equalizer for now to isolate the EDFA's impact on wavelength dependency.
                # self.apply_gain_equalizer(multi_channel_signal_power_mW, multi_channel_noise_power_mW, edfa_total_power, location)


                # Convert signal back to dBm after EDFA (and optional Equalizer)
                epsilon = 1e-18
                multi_channel_signal_power_dBm[:, location] = 10*np.log10(multi_channel_signal_power_mW[:, location] + epsilon)
                # Noise power is kept in mW for accumulation.

        # 3. OSNR Calculation
        epsilon_power = 1e-18
        # Calculate OSNR from mW values
        # OSNR (dB) = 10 * log10(Signal Power (mW) / Noise Power (mW)) in the reference bandwidth (0.1 nm)
        # Noise power in reference bandwidth = multi_channel_noise_power_mW * (reference_bandwidth_nm / snr_bandwidth_nm)
        noise_power_ref_bw_mW = multi_channel_noise_power_mW * (reference_bandwidth_nm / snr_bandwidth_nm)
        noise_power_ref_bw_mW = np.maximum(noise_power_ref_bw_mW, epsilon_power) # Ensure noise power is not zero or negative


        multi_channel_OSNR_dB = 10 * np.log10(multi_channel_signal_power_mW / noise_power_ref_bw_mW)


        # 4. SNR Calculation
        # SNR (dB) = Signal Power (mW in sim BW) / Noise Power (mW in sim BW) in the simulation bandwidth (snr_bandwidth_nm)
        # This is already calculated as multi_channel_signal_power_mW / multi_channel_noise_power_mW
        # The previous calculation of SNR from OSNR was incorrect. We should calculate SNR directly from the signal and noise power in the simulation bandwidth.
        # multi_channel_SNR_dB = multi_channel_OSNR_dB + bandwidth_scaling_dB # This was the incorrect line

        # Correct SNR calculation using signal and noise power in the simulation bandwidth
        multi_channel_SNR_dB = 10 * np.log10(multi_channel_signal_power_mW / multi_channel_noise_power_mW)


        # 5. BER Calculation (using final SNR)
        # Final SNR for BER calculation (assuming DSP handles dispersion)
        final_snr_db = multi_channel_SNR_dB[:, -1]




        modulation_formats = ['BPSK', 'QPSK', '8PSK', '16QAM', 'NRZ-OOK', 'RZ-OOK', 'DPSK']
        ber_results = {}
        for modulation in modulation_formats:
            # Use Q-factor based BER calculation (based on SNR)
            uncoded_ber, q_factor = self.calculate_ber(final_snr_db, modulation=modulation, pre_encoding=None) # calculate_ber returns BER and Q
            ber_results[modulation] = uncoded_ber # Store uncoded BER
            ber_results[f'{modulation}_q_factor'] = q_factor # Store Q-factor



        # Store results
        self.simulation_results = {
            'distance': np.arange(0, length + 1, 1),
            'channeled_lambdas_nm': channeled_lambdas_nm,
            'accumulated_dispersion': calculated_accumulated_dispersion, # Use the new variable name
            'signal_power_dBm': multi_channel_signal_power_dBm,
            'osnr_dB': multi_channel_OSNR_dB,
            'snr_dB': multi_channel_SNR_dB, # Store the original SNR
            'ber_results': ber_results, # Store uncoded, coded BERs and Q-factors
            'channels_to_plot': [0, channels // 4, channels // 2, channels * 3 // 4, channels - 1] # Example channels for dispersion
        }

        # Update plots
        self.update_plots(self.simulation_results)

        self.status_label.config(text="Simulation Complete", foreground="green")


    def update_plots(self, simulation_results):
        # Clear previous plots and create axes
        if self.dispersion_fig:
            self.dispersion_fig.clear()
            ax_dispersion = self.dispersion_fig.add_subplot(111)
            # Plot Accumulated Dispersion
            distance_km = simulation_results.get('distance', np.array([])) # Use .get with default empty array
            accumulated_dispersion = simulation_results.get('accumulated_dispersion', np.array([[]])) # Use .get with default empty array
            channels_to_plot_disp = simulation_results.get('channels_to_plot', []) # Use .get with default empty list
            channeled_lambdas_nm = simulation_results.get('channeled_lambdas_nm', np.array([])) # Use .get with default empty array

            # Add check to ensure channels_to_plot_disp is iterable and not empty
            if isinstance(channels_to_plot_disp, (list, np.ndarray)) and len(channels_to_plot_disp) > 0:
                 for channel_index in channels_to_plot_disp:
                     if accumulated_dispersion.shape[0] > channel_index and accumulated_dispersion.shape[1] == len(distance_km):
                         ax_dispersion.plot(distance_km, accumulated_dispersion[channel_index, :], label=f'Channel {channel_index} Dispersion (ps/nm)')
                     else:
                         # print(f"Warning: Could not plot dispersion for channel {channel_index} due to shape mismatch or invalid index.")
                         pass # Suppress repetitive warnings


            ax_dispersion.set_xlabel("Distance (km)")
            ax_dispersion.set_ylabel("Accumulated Dispersion (ps/nm)")
            ax_dispersion.set_title("Accumulated Chromatic Dispersion along Fiber")
            ax_dispersion.legend()
            ax_dispersion.grid(True)
            if self.dispersion_canvas:
                self.dispersion_canvas.draw()


        if self.power_fig:
            self.power_fig.clear()
            ax_power = self.power_fig.add_subplot(111)
            # Plot Signal Power
            signal_power_dBm = simulation_results.get('signal_power_dBm', np.array([[]])) # Use .get with default empty array
            channels_to_plot_power = [0, signal_power_dBm.shape[0] // 2, signal_power_dBm.shape[0] - 1] if signal_power_dBm.shape[0] > 0 else [] # Plot first, middle, and last channel

            if isinstance(channels_to_plot_power, (list, np.ndarray)) and len(channels_to_plot_power) > 0:
                distance_km = simulation_results.get('distance', np.array([]))
                for channel_index in channels_to_plot_power:
                     if signal_power_dBm.shape[0] > channel_index and signal_power_dBm.shape[1] == len(distance_km):
                          ax_power.plot(distance_km, signal_power_dBm[channel_index, :], label=f'Channel {channel_index}')
                     else:
                         # print(f"Warning: Could not plot power for channel {channel_index} due to shape mismatch or invalid index.")
                         pass # Suppress repetitive warnings

            ax_power.set_xlabel("Distance (km)")
            ax_power.set_ylabel("Signal Power (dBm)")
            ax_power.set_title("Signal Power over Fiber Length")
            ax_power.legend()
            ax_power.grid(True)
            if self.power_canvas:
                self.power_canvas.draw()

        if self.osnr_distance_fig and self.osnr_wavelength_fig:
            self.osnr_distance_fig.clear()
            self.osnr_wavelength_fig.clear()
            ax_osnr_distance = self.osnr_distance_fig.add_subplot(111)
            ax_osnr_wavelength = self.osnr_wavelength_fig.add_subplot(111)

            # Plot OSNR over Distance
            osnr_dB = simulation_results.get('osnr_dB', np.array([[]])) # Use .get with default empty array
            channels_to_plot_osnr_dist = [0, osnr_dB.shape[0] // 2, osnr_dB.shape[0] - 1] if osnr_dB.shape[0] > 0 else [] # Use the same channels as power plot

            if isinstance(channels_to_plot_osnr_dist, (list, np.ndarray)) and len(channels_to_plot_osnr_dist) > 0:
                distance_km = simulation_results.get('distance', np.array([]))
                for channel_index in channels_to_plot_osnr_dist:
                    if osnr_dB.shape[0] > channel_index and osnr_dB.shape[1] == len(distance_km):
                         ax_osnr_distance.plot(distance_km, osnr_dB[channel_index, :], label=f'Channel {channel_index}')
                    else:
                        # print(f"Warning: Could not plot OSNR over distance for channel {channel_index} due to shape mismatch or invalid index.")
                        pass # Suppress repetitive warnings


            ax_osnr_distance.set_xlabel("Distance (km)")
            ax_osnr_distance.set_ylabel("OSNR (dB)")
            ax_osnr_distance.set_title("OSNR over Fiber Length")
            ax_osnr_distance.grid(True)
            ax_osnr_distance.legend()
            if self.osnr_distance_canvas:
                self.osnr_distance_canvas.draw()


            # Plot OSNR across Wavelengths (at the end of the link)
            if osnr_dB.shape[0] > 0 and osnr_dB.shape[1] > 0:
                 channeled_lambdas_nm = simulation_results.get('channeled_lambdas_nm', np.array([]))
                 if len(channeled_lambdas_nm) == osnr_dB.shape[0]:
                    ax_osnr_wavelength.plot(channeled_lambdas_nm, osnr_dB[:, -1])
                    ax_osnr_wavelength.set_xlabel("Wavelength (nm)")
                    ax_osnr_wavelength.set_ylabel("OSNR (dB)")
                    ax_osnr_wavelength.set_title("Final OSNR across Channels")
                    ax_osnr_wavelength.grid(True)
                 else:
                    # print("Warning: Wavelengths data mismatch for OSNR across wavelengths plot.")
                    pass # Suppress warning
            else:
                # print("Warning: Could not plot OSNR across wavelengths due to empty OSNR data.")
                pass # Suppress warning
            if self.osnr_wavelength_canvas:
                self.osnr_wavelength_canvas.draw()


        if self.ber_fig:
            self.ber_fig.clear()
            ax_ber = self.ber_fig.add_subplot(111)
            # Plot BER across Wavelengths for different modulation formats
            ber_results = simulation_results.get('ber_results', {}) # Use .get with default empty dictionary
            modulation_formats = ['BPSK', 'QPSK', '8PSK', '16QAM', 'NRZ-OOK', 'RZ-OOK', 'DPSK']

            if ber_results: # Check if ber_results is not empty
                channeled_lambdas_nm = simulation_results.get('channeled_lambdas_nm', np.array([]))
                if len(channeled_lambdas_nm) > 0:
                    for modulation in modulation_formats:
                        uncoded_ber_key = modulation # Key for the uncoded BER
                        q_factor_key = f'{modulation}_q_factor' # Key for the Q-factor

                        uncoded_ber_data = ber_results.get(uncoded_ber_key, None)

                        q_factor_data = ber_results.get(q_factor_key, None) # Get Q-factor data

                        plot_data = False
                        if uncoded_ber_data is not None and len(uncoded_ber_data) == len(channeled_lambdas_nm):
                            ax_ber.semilogy(channeled_lambdas_nm, uncoded_ber_data, label=f'{modulation} BER (Uncoded)') # Label as Uncoded
                            plot_data = True


                        # Optional: Plot Q-factor on a secondary axis or a separate plot if needed
                        # For now, we just calculate and store it.

                    if plot_data: # Only set limits and legend if any data was plotted
                        ax_ber.set_xlabel("Wavelength (nm)")
                        ax_ber.set_ylabel("BER")
                        ax_ber.set_title("Final BER across Channels for Different Modulation Formats")
                        ax_ber.grid(True, which="both")
                        ax_ber.legend()


                else:
                    # print("Warning: Could not plot BER due to empty channeled wavelengths data.")
                    pass # Suppress warning
            else:
                 # print("Warning: Could not plot BER due to empty BER results.")
                 pass # Suppress warning


            if self.ber_canvas:
                self.ber_canvas.draw()


    # --- Helper Functions ---

    def calculate_wavelength_dependent_attenuation(self, cable_type_params, lambda_signal_m):
        """
        Calculates the attenuation per kilometer at a given signal wavelength (in meters)
        using a quadratic model relative to a reference wavelength (assumed 1550 nm or 1550e-9 m).
        Returns attenuation in dB/km.
        """
        alpha_min = cable_type_params.get('attenuation_per_km', 0) # in dB/km at reference wavelength
        beta_quadratic = cable_type_params.get('attenuation_slope_quadratic', 0) # in dB/km/nm^2
        signal_wavelength_nm = lambda_signal_m * 1e9 # Convert signal wavelength to nm
        min_attenuation_wavelength_nm = 1550 # Reference wavelength in nm

        wavelength_difference_nm = signal_wavelength_nm - min_attenuation_wavelength_nm
        attenuation_at_lambda = alpha_min + beta_quadratic * ((wavelength_difference_nm)**2)
        return attenuation_at_lambda if attenuation_at_lambda > 0 else 0


    def calculate_dispersion_at_wavelength(self, cable_type_params, lambda_signal_m):
        """
        Calculates the dispersion parameter (D in ps/nm/km) at a given signal wavelength (in meters)
        using a linear model relative to a reference wavelength (assumed 1550 nm or 1550e-9 m).
        Returns dispersion in ps/nm/km.
        """
        D_ref = cable_type_params.get('dispersion', 0) # in ps/nm/km at reference wavelength
        S_ref = cable_type_params.get('dispersion_slope', 0) # in ps/nm^2/km at reference wavelength
        signal_wavelength_nm = lambda_signal_m * 1e9 # Convert signal wavelength to nm
        lambda_ref_nm = 1550 # Reference wavelength in nm

        # The formula for dispersion D(lambda) is D_ref + S_ref * (lambda_signal - lambda_ref)
        # where lambda_signal and lambda_ref are in nm.
        dispersion_at_lambda = D_ref + S_ref * (signal_wavelength_nm - lambda_ref_nm)
        return dispersion_at_lambda # Returns dispersion in ps/nm/km


    def chrom_dispersion_total_with_slope(self, length, span_layout, cable_types, channeled_lambdas_m):
        """
        Calculates accumulated chromatic dispersion (in ps/nm) along the link for each channel,
        considering the span layout with multiple fiber types.
        channeled_lambdas_m are in meters.
        Returns accumulated dispersion in ps/nm at each km.
        """
        num_channels = len(channeled_lambdas_m) # Wavelengths are in meters
        accumulated_dispersion = np.zeros((num_channels, length + 1)) # Accumulated dispersion in ps/nm along the link

        # Build the list of fiber types for each kilometer of the link
        span_length = sum(seg_length for _, seg_length in span_layout) # Sum of segment lengths in the span layout

        if span_length <= 0:
             # This case should be handled by input validation before calling this function,
             # but adding a defensive check.
             print("Error: Total span length is 0 or negative in chrom_dispersion_total_with_slope. Cannot calculate dispersion.")
             return np.zeros((num_channels, length + 1))

        # Create the fiber type list for the entire link based on repeating the span layout
        # The link consists of 'length' km. EDFAs are at 'span', '2*span', ... km.
        # Each section between EDFAs (of length 'span') uses the 'span_layout'.
        # Since we validated that sum(span_layout lengths) == span, we can just repeat the span_layout.
        link_fiber_types = []
        for km in range(length):
             km_within_span = km % span_length # Kilometer within the current span layout repetition (0 to span_length-1)
             current_pos_in_span_layout = 0
             current_fiber_type = None
             for fiber_type, seg_length in span_layout:
                  if km_within_span < current_pos_in_span_layout + seg_length:
                       current_fiber_type = fiber_type
                       break
                  current_pos_in_span_layout += seg_length # Increment current_pos_in_span_layout by the segment length


             link_fiber_types.append(current_fiber_type)


        # Add a check if link_fiber_types construction resulted in None (shouldn't happen if span_layout is valid)
        if None in link_fiber_types:
             self.status_label.config(text="Internal Error: Could not determine fiber type for all kilometers.", foreground="red")
             print("Internal Error: Could not determine fiber type for all kilometers.")
             return np.zeros((num_channels, length + 1)) # Return zeros on internal error


        # Iterate through each kilometer and accumulate dispersion
        for i in range(length): # Iterate through each kilometer (from 0 to length-1)
            # Get the fiber type for the current kilometer
            # Add a check if link_fiber_types is empty or index is out of bounds (defensive)
            if not link_fiber_types or i >= len(link_fiber_types):
                 print(f"Internal Error: link_fiber_types issue at kilometer {i}.")
                 # Return current accumulated dispersion or zeros
                 return accumulated_dispersion[:, :i+1] if i > 0 else np.zeros((num_channels, length + 1))


            current_fiber_type_name = link_fiber_types[i]

            # Add a check if current_fiber_type_name exists in cable_types (should be covered by initial validation, but defensive)
            if current_fiber_type_name not in cable_types:
                 print(f"Internal Error: Fiber type '{current_fiber_type_name}' not found in cable_types during dispersion accumulation at km {i}.")
                 # Decide how to handle - maybe skip this kilometer's contribution
                 # For now, let's just use the previous accumulated dispersion for this kilometer
                 accumulated_dispersion[:, i+1] = accumulated_dispersion[:, i]
                 continue # Skip this iteration


            current_cable_type_params = cable_types[current_fiber_type_name]

            for j in range(num_channels):
                # calculate_dispersion_at_wavelength returns dispersion in ps/nm/km
                dispersion_per_km_at_lambda = self.calculate_dispersion_at_wavelength(
                    current_cable_type_params, channeled_lambdas_m[j] # channeled_lambdas in meters
                )
                accumulated_dispersion[j, i+1] = accumulated_dispersion[j, i] + dispersion_per_km_at_lambda # Add dispersion for 1 km


        return accumulated_dispersion


    def calculate_preemphasis_powers(self, channels, edfa_gain_weights_db, target_total_initial_power_dBm):
        """
        Calculates initial channel powers with pre-emphasis based on inverse EDFA gain weights.
        edfa_gain_weights_db should represent the target gain shape (e.g., span attenuation).
        """
        target_total_initial_power_mW = 10**(target_total_initial_power_dBm / 10)

        # The target gain shape is the span attenuation in dB.
        # To pre-emphasize, we want higher initial power for channels with higher attenuation.
        # This means the initial power shape should be proportional to the span attenuation in linear terms.
        initial_power_shape_linear = 10**(edfa_gain_weights_db / 10) # Linear terms of span attenuation

        # Normalize the power shape and scale to the target total initial power
        # Ensure sum is not zero or negative before division
        sum_initial_power_shape_linear = np.sum(initial_power_shape_linear)
        epsilon = 1e-20 # Small value to prevent division by zero
        if sum_initial_power_shape_linear <= epsilon:
             # print("Warning: Sum of initial power shape is zero or negative. Cannot normalize for pre-emphasis.") # Suppress warning
             # In case of issue, distribute power equally
             normalized_initial_power_mW = np.ones(channels) * (target_total_initial_power_mW / channels)
        else:
             normalized_initial_power_mW = (initial_power_shape_linear / sum_initial_power_shape_linear) * target_total_initial_power_mW

        # Convert normalized power back to dBm
        # Ensure normalized_initial_power_mW is not zero or negative before taking log10
        normalized_initial_power_mW = np.maximum(normalized_initial_power_mW, epsilon)
        channel_power_dbm = 10*np.log10(normalized_initial_power_mW)

        return channel_power_dbm


    def const_power_EDFA_with_tilt_saturation(self, multi_channel_signal_power_mW, multi_channel_noise_power_mW, const_power_dBm, channeled_lambdas_m, EDFA_noise, location, span_attenuation_dB_per_channel):
        """
        Simulates a constant total output power EDFA with gain tilt to compensate for span loss,
        including saturation and ASE noise.
        channeled_lambdas_m are in meters.
        Applies EDFA effect at a specific 'location' (kilometer index + 1).
        span_attenuation_dB_per_channel is the loss over one span for each channel.
        multi_channel_signal_power_mW and multi_channel_noise_power_mW are in mW.
        """
        channeled_lambdas_m = np.asarray(channeled_lambdas_m) # Ensure it's an array
        EDFA_noise_linear = 10**(EDFA_noise/10) # Noise figure in linear terms
        planks_const = 6.626e-34 # J*s
        c = 3e8 # Speed of light in m/s
        epsilon = 1e-20 # Small value to prevent division by zero


        # Target total output power
        target_total_output_power_dBm = const_power_dBm
        target_total_output_power_mW = 10**(target_total_output_power_dBm / 10)

        # Input power to the EDFA at the current location
        current_signal_input_mW = multi_channel_signal_power_mW[:, location].copy()
        current_noise_input_mW = multi_channel_noise_power_mW[:, location].copy()


        # Calculate the total input signal power to the EDFA
        total_signal_input_power_mW = np.sum(current_signal_input_mW)
        # Ensure total input power sum is not zero or negative for gain calculation
        total_signal_input_power_mW = np.maximum(total_signal_input_power_mW, epsilon)


        # Calculate the required gain for each channel to compensate for span loss
        # This is the ideal gain *shape*. The total gain will be adjusted for constant total output power.
        required_gain_shape_linear = 10**(span_attenuation_dB_per_channel / 10)


        # Calculate the total output power needed if applying just the required gain shape
        output_power_with_shape_mW = current_signal_input_mW * required_gain_shape_linear
        total_output_power_with_shape_mW = np.sum(output_power_with_shape_mW)


        # Calculate the scaling factor to achieve the target total output power
        # This scaling factor adjusts the overall gain level while preserving the gain shape
        scaling_factor = target_total_output_power_mW / (total_output_power_with_shape_mW + epsilon)


        # Apply the scaled gain (preserving shape) to signal and noise
        actual_linear_gain_per_channel = required_gain_shape_linear * scaling_factor

        multi_channel_signal_power_mW[:, location] = current_signal_input_mW * actual_linear_gain_per_channel
        multi_channel_noise_power_mW[:, location] = current_noise_input_mW * actual_linear_gain_per_channel # Noise is also amplified


        # Add Amplified Spontaneous Emission (ASE) noise
        # ASE noise power per channel = n_sp * h * nu * Bandwidth * Gain
        # n_sp (Spontaneous Emission Factor) is related to Noise Figure (NF) by NF = 10 * log10(n_sp * 2)
        # So n_sp = 10**(EDFA_noise/10) / 2
        n_sp_linear = 10**(EDFA_noise/10) / 2
        f_channeled = c / channeled_lambdas_m # Frequency in Hz

        # Noise bandwidth in Hz for OSNR definition (0.1 nm reference bandwidth)
        # Using the reference bandwidth (0.1 nm) converted to Hz at each channel's frequency for ASE calculation.
        # The bandwidth in Hz for a given optical bandwidth (d_lambda) is c * d_lambda / lambda^2
        reference_bandwidth_hz_per_channel = c * default_params['reference_bandwidth_nm'] / (channeled_lambdas_m**2)


        # The gain used for ASE calculation is the *actual* linear gain applied to each channel.
        power_ase_added_mW_per_channel = n_sp_linear * planks_const * f_channeled * reference_bandwidth_hz_per_channel * actual_linear_gain_per_channel * 1000 # Convert to mW
        power_ase_added_mW_per_channel[power_ase_added_mW_per_channel < 0] = 0 # Ensure noise is non-negative


        multi_channel_noise_power_mW[:, location] += power_ase_added_mW_per_channel


    def apply_gain_equalizer(self, multi_channel_signal_power_mW, multi_channel_noise_power_mW, target_total_power_dBm, location):
        """
        Applies a gain equalizer after the EDFA to flatten the signal power across channels.
        Assumes equalization to a target total power, distributing equally among channels.
        Applies equalizer effect at a specific 'location' (kilometer index + 1).
        multi_channel_signal_power_mW and multi_channel_noise_power_mW are in mW.
        """
        epsilon = 1e-20
        # Input power to the equalizer is the power at 'location' after EDFA gain and ASE noise
        current_signal_input_mW = np.asarray(multi_channel_signal_power_mW[:, location].copy())
        current_noise_input_mW = np.asarray(multi_channel_noise_power_mW[:, location].copy())
        # Note: current_signal_input_mW and current_noise_input_mW are arrays of shape (channels,)

        target_total_power_mW = 10**(target_total_power_dBm / 10)
        num_channels = len(current_signal_input_mW) # Number of channels
        if num_channels == 0:
             # print("Warning: No channels to equalize.") # Suppress warning
             return # Cannot equalize if no channels

        target_power_per_channel_mW = target_total_power_mW / num_channels # Equal power per channel

        gain_needed_linear_per_channel = np.zeros_like(current_signal_input_mW)
        # Calculate gain needed for signal power to reach target per-channel power
        non_zero_signal_mask = current_signal_input_mW > epsilon
        gain_needed_linear_per_channel[non_zero_signal_mask] = target_power_per_channel_mW / (current_signal_input_mW[non_zero_signal_mask] + epsilon)
        # For channels with zero input signal power, gain can be set to 1 (no change) or 0 (no output), depending on the model.
        # Setting to 1 ensures noise is passed through, consistent with an equalizer not adding/removing noise based on signal.
        gain_needed_linear_per_channel[~non_zero_signal_mask] = 1.0 # Or 0.0 if equalizer blocks empty channels

        # Apply the same gain to both signal and noise for each channel
        multi_channel_signal_power_mW[:, location] = current_signal_input_mW * gain_needed_linear_per_channel
        multi_channel_noise_power_mW[:, location] = current_noise_input_mW * gain_needed_linear_per_channel


    def calculate_ber(self, snr_db, modulation='QPSK', pre_encoding=None):
        """
        Calculates BER based on electrical SNR (in dB) and modulation format, with optional FEC.
        This function assumes the input `snr_db` is the electrical SNR *per symbol* after the receiver filter.
        Adjusts SNR based on bits per symbol for Q-factor calculation to compare different formats fairly.
        Returns both BER and the calculated Q-factor.
        """
        snr_linear = 10**(snr_db / 10)
        snr_linear = np.maximum(snr_linear, 1e-10) # Prevent log/sqrt issues with zero or negative SNR

        Q = np.zeros_like(snr_linear) # Initialize Q-factor array
        uncoded_ber = np.zeros_like(snr_linear) # Initialize uncoded_ber array

        modulation = modulation.lower()

        # Determine bits per symbol for the modulation format
        bits_per_symbol_dict = {
            'bpsk': 1,
            'qpsk': 2,
            '8psk': 3,
            '16qam': 4,
            'nrz-ook': 1, # OOK transmits 1 bit per symbol
            'rz-ook': 1, # OOK transmits 1 bit per symbol
            'dpsk': 1 # DPSK transmits 1 bit per symbol
        }
        bits_per_symbol = bits_per_symbol_dict.get(modulation, 1) # Default to 1 for unknown


        # Calculate electrical SNR per bit (Eb/N0) from the provided electrical SNR per symbol
        # SNR_linear (per symbol) = Eb/N0_linear * bits_per_symbol
        # So, Eb/N0_linear = SNR_linear / bits_per_symbol
        eb_n0_linear = snr_linear / bits_per_symbol
        # Ensure Eb/N0 is not zero or negative
        eb_n0_linear = np.maximum(eb_n0_linear, 1e-10)


        # Calculate Q-factor and Uncoded BER based on modulation format and Eb/N0_linear
        # These are theoretical Q and BER in AWGN for a given Eb/N0
        if modulation == 'bpsk':
            # For BPSK, Q = sqrt(2 * Eb/N0)
            Q = np.sqrt(2 * eb_n0_linear)
            uncoded_ber = 0.5 * scipy.special.erfc(Q / np.sqrt(2))
        elif modulation == 'qpsk':
            # For QPSK, Q = sqrt(2 * Eb/N0) (BER per bit)
            Q = np.sqrt(2 * eb_n0_linear)
            uncoded_ber = 0.5 * scipy.special.erfc(Q / np.sqrt(2))
        elif modulation == '8psk':
            # For 8PSK, approximate BER based on distance in constellation
            # A common approximation for MPSK (M>2) BER is 1/log2(M) * erfc(sqrt(log2(M) * Eb/N0) * sin(pi/M))
            # Let's use this directly for BER calculation
            M = 8
            # Argument for erfc: sqrt(log2(M) * Eb/N0_linear) * sin(pi/M)
            erfc_arg = np.sqrt(bits_per_symbol * eb_n0_linear) * np.sin(np.pi / M)
            uncoded_ber = (1.0 / bits_per_symbol) * scipy.special.erfc(erfc_arg) # BER per bit
            # Q-factor can be related to the argument: Q = sqrt(2) * erfc_arg
            Q = np.sqrt(2) * erfc_arg
        elif modulation == '16qam':
             # Approximation for M-QAM BER based on distance in constellation
             # For M-QAM, BER is approx 1/log2(M) * erfc(sqrt(3 * log2(M) * Eb/N0 / (M - 1)))
             M = 16
             # Argument for erfc: sqrt(3 * log2(M) * Eb/N0_linear / (M - 1))
             erfc_arg = np.sqrt(3 * bits_per_symbol * eb_n0_linear / (M - 1))
             uncoded_ber = (1.0 / bits_per_symbol) * scipy.special.erfc(erfc_arg) # BER per bit
             # Q-factor can be related to the argument: Q = sqrt(2) * erfc_arg
             Q = np.sqrt(2) * erfc_arg
        elif modulation == 'nrz-ook' or modulation == 'rz-ook':
            # Simplified Q for OOK (assuming ideal detection and extinction ratio), based on Eb/N0
            # For OOK, Q = sqrt(Eb/N0)
            Q = np.sqrt(eb_n0_linear)
            uncoded_ber = 0.5 * scipy.special.erfc(Q / np.sqrt(2))
        elif modulation == 'dpsk':
            # Correct Uncoded BER for ideal DPSK in AWGN = 0.5 * exp(-Eb/N0_linear)
            uncoded_ber = 0.5 * np.exp(-eb_n0_linear)
            # Q-factor for DPSK is often defined differently or not directly used with erfc in the same way.
            # A common approach is to find the Q that would give the same BER from the BPSK/QPSK formula.
            # BER_DPSK = 0.5 * exp(-Eb/N0)
            # BER_BPSK = 0.5 * erfc(Q_BPSK/sqrt(2))
            # Setting BER_DPSK = BER_BPSK: 0.5 * exp(-Eb/N0_DPSK) = 0.5 * erfc(Q_DPSK/sqrt(2))
            # erfc(Q_DPSK/sqrt(2)) = exp(-Eb/N0_DPSK)
            # Q_DPSK/sqrt(2) = erfcinv(exp(-Eb/N0_DPSK))
            # Q_DPSK = sqrt(2) * erfcinv(exp(-eb_n0_linear))
            # Note: erfcinv is the inverse of erfc. scipy.special.erfcinv might not be directly available or stable.
            # Let's use the relationship Q^2 = 2 * Eb/N0 for BPSK and map DPSK BER to an equivalent BPSK Q.
            # exp(-Eb/N0_linear) = erfc(Q_equiv / sqrt(2))
            # Q_equiv = sqrt(2) * erfcinv(exp(-eb_n0_linear)) # This is difficult to compute directly.
            # Alternatively, the Q-factor for DPSK is sometimes related to sqrt(Eb/N0) for moderate to high SNR.
            Q = np.sqrt(eb_n0_linear) # Using sqrt(Eb/N0) as an approximation for Q for DPSK
        else:
            # Default to BPSK calculation if modulation not recognized, using Eb/N0
            bits_per_symbol = 1 # Assume 1 for unknown
            eb_n0_linear = snr_linear / bits_per_symbol
            eb_n0_linear = np.maximum(eb_n0_linear, 1e-10)
            Q = np.sqrt(2 * eb_n0_linear)
            uncoded_ber = 0.5 * scipy.special.erfc(Q / np.sqrt(2))


        # Ensure Q is non-negative for erfc and other calculations
        Q = np.maximum(Q, 0)
        # Ensure uncoded_ber is within valid range [0, 1] before applying FEC
        uncoded_ber = np.clip(uncoded_ber, 0, 1)



        return uncoded_ber, Q # Return both BER and the calculated Q-factor



if __name__ == "__main__":
    root = tk.Tk()
    gui = OpticalFiberSimulatorGUI(root)
    root.mainloop()

TclError: no display name and no $DISPLAY environment variable