# In [None]:
# This file produces the wave documentation data and plots it.
# Six wave probes were placed along the tank: five on the centreline and one halfway between the centreline and the side wall, to detect transverse waves.
# All wave documentation data is in C:\Users\HP\OneDrive - NTNU\Desktop\Master\Code\PostProcess\Resultater\WaveDocumentation
# The tank is 25 m long, 2.5 m wide and 0.7 m deep. The code should take reflections into account.


import numpy as np
import matplotlib.pyplot as plt
import math
from scipy.optimize import fsolve
import os # Added for potential file system operations, though not used in this snippet
import pandas as pd # For handling experimental data
import scipy.signal # For Butterworth filter
from apread import APReader # Added for reading .bin files

# Root folder holding the wave-documentation .bin recordings.
wavedocumentation_path = r"H:\CodeLånePC\Resultater\WaveDocumentation"

# Still-water depth in the tank [m].
water_depth = 0.7

# Longitudinal (x) position of each wave probe [m], keyed by channel name.
# WP5 and WP8 share x = 11 m; WP8 is the off-centre probe (see WP8_y).
wave_probe_positions_x = dict(
    WP3=5,
    WP4=9,
    WP5=11,
    WP6=13,
    WP7=17,
    WP8=11,
)

# Transverse (y) distance from the side wall [m]. The tank is 2.5 m wide, so
# 1.25 m is the centreline and 0.625 m is halfway between centreline and wall.
WP3_y = WP4_y = WP5_y = WP6_y = WP7_y = 1.25
WP8_y = 0.625  # off-centre probe, intended to pick up transverse waves

# Physical constants
g = 9.81  # gravitational acceleration [m/s^2]
# Length used in the reflection travel-time estimate [m].
# NOTE(review): the file header describes a 25 m tank; confirm whether 22 m is
# the intended distance from the wavemaker to the reflecting end.
L_tank = 22.0

# --- Data Loading Functions ---
def search_file(file_name, root_folder):
    """Return the path of the first file named *file_name* under *root_folder*.

    The tree is walked top-down with ``os.walk``; the first directory whose
    listing contains *file_name* wins.

    Raises:
        FileNotFoundError: if no directory in the tree contains *file_name*
            (including when *root_folder* does not exist).
    """
    candidates = (
        os.path.join(dirpath, file_name)
        for dirpath, _subdirs, filenames in os.walk(root_folder)
        if file_name in filenames
    )
    match = next(candidates, None)
    if match is None:
        raise FileNotFoundError(f"File '{file_name}' not found in '{root_folder}' or its subdirectories.")
    return match


def _apreader_channel_name(ch, index):
    """Return a display name for APReader channel *ch* at position *index*.

    Prefers the text between the first pair of double quotes in ``str(ch)``,
    then ``ch.Name``, and finally the placeholder ``UnnamedChannel_<index>``.
    """
    try:
        quoted_parts = str(ch).split('"')
        if len(quoted_parts) > 1:
            return quoted_parts[1]
        return str(ch.Name)
    except Exception:
        try:
            return str(ch.Name)
        except AttributeError:
            return f"UnnamedChannel_{index}"


def bin_to_dataframe(full_file_path):
    """
    Read a .bin file with APReader and return its time and wave-probe channels.

    Args:
        full_file_path (str): Path to the .bin recording.

    Returns:
        pandas.DataFrame: A "Time" column, renamed from the first channel whose
        name contains "time" (case-insensitive), followed by whichever channels
        listed in ``wave_probe_positions_x`` are present, in that order.

    Raises:
        IOError: if APReader cannot open or read the file; the original
            exception is chained as ``__cause__``.
        ValueError: if no channel name contains "time".
    """
    try:
        reader = APReader(full_file_path)
    except Exception as e:
        # Chain the original error so its traceback is preserved for debugging.
        raise IOError(f"Error opening or reading file {full_file_path} with APReader: {e}") from e

    # Single pass: channel name -> samples. If two channels resolve to the same
    # name, the later one overwrites the earlier one.
    data_dict = {}
    for index, ch in enumerate(reader.Channels):
        data_dict[_apreader_channel_name(ch, index)] = ch.data

    time_key_found = next((key for key in data_dict if "time" in key.lower()), None)
    if time_key_found is None:
        raise ValueError(f"No 'Time' channel found in {full_file_path}. Available channels: {list(data_dict.keys())}")

    df = pd.DataFrame(data_dict)

    desired_wp_columns = [wp for wp in wave_probe_positions_x if wp in df.columns]
    df_subset = df[[time_key_found] + desired_wp_columns]
    return df_subset.rename(columns={time_key_found: "Time"})

# --- End Data Loading Functions --
# 

def solve_dispersion_for_k(period, water_depth, g_const):
    """
    Solve the linear dispersion relation omega^2 = g * k * tanh(k * h) for k.

    f(k) = g * k * tanh(k * h) - omega^2 is strictly increasing for k > 0, with
    f(0) = -omega^2 < 0 and f(k_deep + k_shallow) > 0 for every depth, where
    k_deep = omega^2 / g and k_shallow = omega / sqrt(g * h) are the deep- and
    shallow-water limits. Brent's method on that bracket therefore always
    converges to the single positive root, with no initial guess needed.

    Args:
        period (float): Wave period in seconds.
        water_depth (float): Water depth in meters.
        g_const (float): Acceleration due to gravity.

    Returns:
        float: Wave number k (rad/m), or np.nan if any input is non-positive or
        non-finite, or if the solver fails.
    """
    # Local import: the module only imports fsolve from scipy.optimize.
    from scipy.optimize import brentq

    if not all(np.isfinite(v) and v > 0 for v in (period, water_depth, g_const)):
        return np.nan

    omega_sq = (2.0 * np.pi / period) ** 2
    k_deep = omega_sq / g_const
    k_shallow = np.sqrt(omega_sq / (g_const * water_depth))
    k_upper = k_deep + k_shallow  # f(k_upper) > 0 is guaranteed (see docstring)
    if not np.isfinite(k_upper):
        return np.nan

    def dispersion_residual(k_val):
        return g_const * k_val * np.tanh(k_val * water_depth) - omega_sq

    try:
        k_solution = brentq(dispersion_residual, 0.0, k_upper)
    except (ValueError, RuntimeError):
        return np.nan
    return float(k_solution) if k_solution > 0 else np.nan

def extract_variable_from_filename(file_name, variable):
    """
    Parse one metadata field out of a wave-documentation filename.

    Expected layout: ``date-period-steepness(info)#test.bin``. For example,
    ``1701-0_525-40(WaveDoc)#1.bin`` gives date "1701", period 0.525,
    steepness 40.0, info "WaveDoc" and test 1. Underscores in period and
    steepness act as decimal points. These variations are tolerated:

    * no ``#test`` part: the test is reported as "Unknown", and a trailing
      ``.bin`` is stripped from the steepness segment;
    * ``steepness-info`` instead of ``steepness(info)``;
    * an ``S``/``s`` prefix on the steepness (``S0_02`` -> 0.02).

    Args:
        file_name (str): Bare filename, without the directory.
        variable (str): Case-insensitive field name: "date", "velocity",
            "period", "steepness", "info", "test" or "testnumber".

    Returns:
        The requested field:

        * period and steepness: a float when they parse, otherwise the raw
          string (steepness is None when its segment is empty);
        * test: an int when it parses, otherwise the string;
        * velocity: always np.nan, because this format does not encode it;
        * info: a string or None.

    Raises:
        ValueError: if the filename has fewer than three '-'-separated parts,
            or if ``variable`` is not one of the names above.
    """
    segments = file_name.split('-', 2)  # at most 3 pieces; later '-' stay in the tail
    if len(segments) < 3:
        raise ValueError(f"Invalid file name format. Expected 'date-period-steepness(info)#test.bin', got: {file_name}")
    date, period_token, tail = segments
    period_str = period_token.replace("_", ".")

    # Split "steepness(info)" from "test.bin" at the first '#'.
    steepness_raw, hash_sep, test_token = tail.partition('#')
    if hash_sep:
        test = test_token.replace(".bin", "")
    else:
        test = "Unknown"
        if steepness_raw.endswith(".bin"):
            steepness_raw = steepness_raw[:-4]

    # Separate the numeric steepness from the optional info label.
    steep_token, info_val = steepness_raw, None
    if '(' in steepness_raw and ')' in steepness_raw:
        open_at = steepness_raw.find('(')
        close_at = steepness_raw.find(')', open_at)
        steep_token = steepness_raw[:open_at]
        info_val = steepness_raw[open_at + 1:close_at]
    elif '-' in steepness_raw:
        steep_token, _, info_val = steepness_raw.partition('-')

    # Drop an "S"/"s" prefix when it is followed by a digit or '_' (S0_02 -> 0_02).
    if len(steep_token) > 1 and steep_token[0] in "Ss" and (steep_token[1].isdigit() or steep_token[1] == "_"):
        steep_token = steep_token[1:]
    steepness_val_str = steep_token.replace("_", ".") if steep_token else None

    def _coerce(raw, cast):
        # cast(raw) when possible, otherwise hand back the raw value unchanged.
        try:
            return cast(raw)
        except (ValueError, TypeError):
            return raw

    field = variable.lower()
    if field == "date":
        return date
    if field == "velocity":
        return np.nan  # Velocity is not in this filename format
    if field == "period":
        return _coerce(period_str, float)
    if field == "steepness":
        return _coerce(steepness_val_str, float)
    if field == "info":
        return info_val
    if field in ("test", "testnumber"):
        return _coerce(test, int)
    raise ValueError(f"Invalid variable requested: {variable}.")


def calculate_average_steady_state_amplitude(full_file_path, channel_name, 
                                           wave_period_T=None, 
                                           steepness_val_from_filename=None, # H/L * 1000, e.g. 40.0 -> H/L = 0.04
                                           steady_state_start_time=20.0, 
                                           filter_order=4, cutoff_freq_hz=5.0,
                                           initial_skip_periods=15, # fallback window start, in wave periods
                                           end_buffer_periods=10,   # fallback margin before reflection, in wave periods
                                           target_amplitude_factor_for_start=0.3, 
                                           num_peaks_skip_start=10,
                                           num_peaks_skip_end=0,
                                           reflection_time_delay_sec=10.0): # seconds added to the estimated reflection arrival
    """
    Average the crest values of one wave-probe signal over a steady-state window.

    Steps:
    1. Load the recording with ``bin_to_dataframe`` and take the "Time" and
       ``channel_name`` columns. The sampling rate is fs = 1 / mean(dt).
    2. Low-pass the signal with a zero-phase Butterworth filter
       (``scipy.signal.filtfilt``) at ``cutoff_freq_hz``, clamped to
       0.95 * Nyquist. The raw signal is used if the filter cannot be applied.
       Then subtract the mean of the whole record.
    3. Window start. When ``wave_period_T`` and ``steepness_val_from_filename``
       are numbers, the input amplitude a = (steepness / 1000) * wavelength / 2
       is computed from linear dispersion. Peaks are then searched from the
       first sample where the filtered signal reaches
       ``target_amplitude_factor_for_start * a``. The window starts at the
       peak with 0-based index ``num_peaks_skip_start`` among those peaks.
       Fallback: ``initial_skip_periods * wave_period_T``, or
       ``steady_state_start_time`` when no valid period is given.
    4. Window end. The reflection arrival time at the probe is estimated as
       x / Cg + 2 * (L_tank - x) / Cg + ``reflection_time_delay_sec``. Here x
       comes from ``wave_probe_positions_x`` and Cg is the linear group
       velocity. The window ends at the peak ``num_peaks_skip_end`` places
       before the last peak preceding that time. Fallback: the earlier of the
       last sample and (reflection time - ``end_buffer_periods * wave_period_T``),
       or just the last sample when no reflection time could be computed.
    5. Find peaks in the window and return their mean. Peaks must have
       height >= 0.0005 and a minimum spacing of int(0.4 * T * fs) samples,
       or int(fs / (4 * cutoff)) without a valid period.

    Only local maxima (crests) are averaged. For non-linear waves the result
    is therefore the mean crest elevation, not half the crest-to-trough height.

    Args:
        full_file_path (str): The full path to the .bin data file.
        channel_name (str): The wave-probe channel, e.g. "WP3".
        wave_period_T (float, optional): Nominal wave period [s].
        steepness_val_from_filename (float, optional): Steepness H/L scaled by
            1000 (e.g. 40.0), as parsed from the filename.
        steady_state_start_time (float): Fallback start time [s] when no valid
            period is available.
        filter_order (int): Order of the Butterworth filter.
        cutoff_freq_hz (float): Cutoff frequency of the low-pass filter [Hz].
        initial_skip_periods (int): Fallback start, in wave periods from t = 0.
        end_buffer_periods (int): Fallback end margin, in wave periods before
            the estimated reflection arrival.
        target_amplitude_factor_for_start (float): Fraction of the input
            amplitude that marks the end of the ramp-up.
        num_peaks_skip_start (int): Peaks skipped after the ramp-up threshold.
        num_peaks_skip_end (int): Peaks skipped before the reflection time.
        reflection_time_delay_sec (float): Seconds added to the estimated
            reflection arrival time.

    Returns:
        float: The mean crest value in the window, or np.nan if:

        * the file cannot be loaded, or the channel or "Time" column is missing;
        * there are fewer than 10 samples;
        * fs is invalid, or the clamped cutoff is <= 0.01 Hz;
        * the window is empty or inverted;
        * no peaks are found.

    NOTE(review): assumes the "Time" column is monotonically increasing
    (np.searchsorted requires sorted input). Confirm for all recordings.
    NOTE(review): the 0.0005 peak-height threshold is in the probe signal's
    units (presumably metres). Confirm.
    """
    try:
        df = bin_to_dataframe(full_file_path)
    except Exception as e:
        # Best effort: any load failure is reported to the caller as NaN.
        # print(f"Debug: Error loading data for {channel_name} from {os.path.basename(full_file_path)}: {e}")
        return np.nan

    if channel_name not in df.columns:
        return np.nan
    if "Time" not in df.columns:
        return np.nan

    time_signal = df["Time"].values
    wave_signal = df[channel_name].values

    if len(time_signal) < 10 or len(wave_signal) < 10:
        return np.nan

    # Sampling rate from the mean time step.
    fs = 1.0 / np.mean(np.diff(time_signal))
    if not (fs > 0 and np.isfinite(fs)):
        return np.nan
    
    # Keep the cutoff below Nyquist so the normalised frequency Wn stays < 1.
    nyquist_freq = 0.5 * fs
    actual_cutoff_freq_hz = cutoff_freq_hz
    if cutoff_freq_hz >= nyquist_freq:
        actual_cutoff_freq_hz = nyquist_freq * 0.95 
        if actual_cutoff_freq_hz <= 0.01:
            return np.nan 

    Wn = actual_cutoff_freq_hz / nyquist_freq
    if not (0 < Wn < 1):
        filtered_signal = wave_signal
    else:
        b, a = scipy.signal.butter(filter_order, Wn, btype='low', analog=False)
        # filtfilt (zero-phase, forward + backward) needs enough samples for
        # its padding; if it cannot run, the unfiltered signal is used instead.
        if len(wave_signal) > 3 * filter_order: 
            try:
                filtered_signal = scipy.signal.filtfilt(b, a, wave_signal)
            except ValueError: 
                filtered_signal = wave_signal 
        else:
            filtered_signal = wave_signal

    filtered_signal = filtered_signal - np.mean(filtered_signal) # Remove the whole-record mean (offset)

    # --- Theoretical input amplitude, used only for the dynamic window start ---
    # a = (H/L) * wavelength / 2, with H/L = steepness / 1000.
    calculated_input_amplitude = np.nan
    if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0 and \
       steepness_val_from_filename is not None and isinstance(steepness_val_from_filename, (float, int)):
        k_val_ia = solve_dispersion_for_k(wave_period_T, water_depth, g)
        if k_val_ia is not None and not np.isnan(k_val_ia) and k_val_ia > 0:
            wavelength_ia = 2 * np.pi / k_val_ia
            actual_steepness_HL_ia = steepness_val_from_filename / 1000.0
            calculated_input_amplitude = actual_steepness_HL_ia * wavelength_ia / 2.0

    # --- Determine Analysis Start Time (analysis_start_t) ---
    # Sentinel meaning "not set yet".
    # NOTE(review): a genuine start time of exactly -1.0 s would be mistaken for it.
    analysis_start_t = -1
    
    # Dynamic start: skip the ramp-up until the signal reaches a fraction of the input amplitude.
    if not np.isnan(calculated_input_amplitude) and calculated_input_amplitude > 0:
        target_signal_level_for_start = target_amplitude_factor_for_start * calculated_input_amplitude
        # First sample at or above the (positive) threshold level.
        indices_above_target = np.where(filtered_signal >= target_signal_level_for_start)[0]
        
        if len(indices_above_target) > 0:
            first_index_above_target = indices_above_target[0]
            
            # Minimum peak spacing: 0.4 wave periods, expressed in samples.
            min_peak_dist_start = int(fs * wave_period_T * 0.4) if (wave_period_T and wave_period_T > 0 and fs > 0) else int(fs * 0.1)
            if min_peak_dist_start < 1: min_peak_dist_start = 1
            
            peaks_after_threshold_indices_rel, _ = scipy.signal.find_peaks(
                filtered_signal[first_index_above_target:], 
                height=0.0005, # minimal crest height threshold
                distance=min_peak_dist_start
            )
            # Convert indices relative to the slice back to absolute sample indices.
            peaks_after_threshold_indices_abs = peaks_after_threshold_indices_rel + first_index_above_target
            
            if len(peaks_after_threshold_indices_abs) > num_peaks_skip_start:
                # Skip the first num_peaks_skip_start peaks; the window starts at
                # the next one (0-based index num_peaks_skip_start).
                start_peak_abs_index = peaks_after_threshold_indices_abs[num_peaks_skip_start]
                analysis_start_t = time_signal[start_peak_abs_index]
                # print(f"Debug ({channel_name}): Dynamic start time set to {analysis_start_t:.2f}s (peak after {num_peaks_skip_start} skips).")

    # Fallback start time if the dynamic start could not be determined.
    if analysis_start_t == -1:
        if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0:
            analysis_start_t = initial_skip_periods * wave_period_T
            # print(f"Debug ({channel_name}): Fallback start time (T-based): {analysis_start_t:.2f}s.")
        else:
            analysis_start_t = steady_state_start_time # Absolute fallback
            # print(f"Debug ({channel_name}): Fallback start time (fixed): {analysis_start_t:.2f}s.")

    # --- Estimated arrival time of the reflected wave at this probe ---
    reflection_arrival_time_current_channel = np.nan
    if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0:
        k_val = solve_dispersion_for_k(wave_period_T, water_depth, g)
        if k_val is not None and not np.isnan(k_val) and k_val > 0:
            # Linear-wave group velocity: Cg = n * C, n = 0.5 * (1 + 2kh / sinh(2kh)).
            omega = 2 * np.pi / wave_period_T
            C = omega / k_val 
            n_val = 0.5 * (1 + (2 * k_val * water_depth) / np.sinh(2 * k_val * water_depth))
            Cg = n_val * C 
            if Cg > 1e-6:
                wp_x_position = wave_probe_positions_x.get(channel_name)
                if wp_x_position is not None:
                    # Travel 0 -> x, then x -> L_tank -> x, all at Cg.
                    time_to_reach_probe = wp_x_position / Cg
                    time_probe_to_wall_and_back = (2 * (L_tank - wp_x_position)) / Cg
                    if time_probe_to_wall_and_back >= 0:
                        reflection_arrival_time_current_channel = time_to_reach_probe + time_probe_to_wall_and_back
                        if not np.isnan(reflection_arrival_time_current_channel): # Add delay if valid
                            reflection_arrival_time_current_channel += reflection_time_delay_sec

    # --- Determine Analysis End Time (analysis_end_t) ---
    analysis_end_t = -1 # Sentinel meaning "not set yet"

    # Dynamic end: count peaks back from the estimated reflection arrival.
    if not np.isnan(reflection_arrival_time_current_channel):
        idx_reflection_arrival = np.searchsorted(time_signal, reflection_arrival_time_current_channel)
        if idx_reflection_arrival > 0: # Ensure reflection is not at the very beginning
            min_peak_dist_end = int(fs * wave_period_T * 0.4) if (wave_period_T and wave_period_T > 0 and fs > 0) else int(fs * 0.1)
            if min_peak_dist_end < 1: min_peak_dist_end = 1

            peaks_before_reflection_indices_abs, _ = scipy.signal.find_peaks(
                filtered_signal[:idx_reflection_arrival],
                height=0.0005,
                distance=min_peak_dist_end
            )
            if len(peaks_before_reflection_indices_abs) > num_peaks_skip_end:
                # End at the peak num_peaks_skip_end places before the last peak
                # preceding the reflection (0 -> the last peak itself).
                end_peak_abs_index = peaks_before_reflection_indices_abs[-(num_peaks_skip_end + 1)]
                analysis_end_t = time_signal[end_peak_abs_index]
                # print(f"Debug ({channel_name}): Dynamic end time set to {analysis_end_t:.2f}s (peak before {num_peaks_skip_end} skips from reflection).")

    # Fallback end time if the dynamic end failed or no reflection time is known.
    if analysis_end_t == -1:
        fallback_end_t = time_signal[-1]
        if not np.isnan(reflection_arrival_time_current_channel):
            if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0:
                potential_end_t = reflection_arrival_time_current_channel - (end_buffer_periods * wave_period_T)
                fallback_end_t = min(fallback_end_t, potential_end_t)
            else:
                # NOTE(review): a reflection time is only computed when wave_period_T
                # is valid, so this branch appears unreachable.
                fallback_end_t = min(fallback_end_t, reflection_arrival_time_current_channel)
        analysis_end_t = fallback_end_t
        # print(f"Debug ({channel_name}): Fallback end time set to {analysis_end_t:.2f}s.")


    # Ensure analysis_end_t is not before analysis_start_t
    if analysis_end_t <= analysis_start_t:
        # print(f"Debug ({channel_name}): Invalid window. End time {analysis_end_t:.2f}s <= Start time {analysis_start_t:.2f}s.")
        return np.nan

    # Convert window times to sample indices (end inclusive via side='right').
    start_index_steady = np.searchsorted(time_signal, analysis_start_t)
    abs_end_index_for_analysis = np.searchsorted(time_signal, analysis_end_t, side='right') 
    
    if start_index_steady >= abs_end_index_for_analysis:
        # print(f"Debug ({channel_name}): No valid analysis window after indexing. Start_idx {start_index_steady}, End_idx {abs_end_index_for_analysis}. Times: Start {analysis_start_t:.2f}s, End {analysis_end_t:.2f}s")
        return np.nan
    
    steady_signal = filtered_signal[start_index_steady:abs_end_index_for_analysis]
    
    if len(steady_signal) < 2: 
        return np.nan

    min_peak_height = 0.0005 
    
    # Minimum peak spacing in samples: 0.4 periods, or a quarter cutoff period without T.
    if wave_period_T is not None and wave_period_T > 0 and fs > 0:
        min_peak_distance = int(fs * wave_period_T * 0.4) 
    else:
        if actual_cutoff_freq_hz > 0:
             min_peak_distance = int(fs / (actual_cutoff_freq_hz * 4)) 
        else: 
            min_peak_distance = int(fs * 0.1) 

    if min_peak_distance < 1: min_peak_distance = 1

    peaks_indices, _ = scipy.signal.find_peaks(steady_signal, height=min_peak_height, distance=min_peak_distance)
    
    if len(peaks_indices) == 0:
        # print(f"Debug ({channel_name}): No peaks found in the final analysis window {analysis_start_t:.2f}s - {analysis_end_t:.2f}s.")
        return np.nan
        
    amplitudes = steady_signal[peaks_indices]
    # Average every crest found in the window.
    average_amplitude = np.mean(amplitudes)
    
    return average_amplitude


def create_dataframe():
    """
    Build a DataFrame with one row per .bin recording in ``wavedocumentation_path``.

    For each file:

    * the filename metadata (date, velocity, period, steepness, info, test
      number) is parsed with ``extract_variable_from_filename``;
    * a theoretical input amplitude is derived from period and steepness;
    * ``calculate_average_steady_state_amplitude`` is evaluated for every
      probe in ``wave_probe_positions_x``.

    Files whose names cannot be parsed are skipped with a message.

    Files are processed in sorted filename order, because ``os.listdir``
    returns entries in arbitrary order and sorting keeps the row order
    reproducible. Entries that end in ".bin" but are not regular files are
    ignored.

    Returns:
        pandas.DataFrame: Columns filename, date, velocity, period, steepness,
        info, test_number, inputAmplitude, and ``<probe>_Avg_Amp`` for each
        probe. The DataFrame is empty if the directory does not exist or no
        file was processed.
    """
    if not os.path.isdir(wavedocumentation_path):
        print(f"Error: Directory not found - {wavedocumentation_path}")
        return pd.DataFrame()

    print(f"Processing files in: {wavedocumentation_path}")
    files_to_process = sorted(
        f for f in os.listdir(wavedocumentation_path)
        if f.endswith(".bin") and os.path.isfile(os.path.join(wavedocumentation_path, f))
    )
    n_files = len(files_to_process)
    print(f"Found {n_files} .bin files.")

    all_data_rows = []
    for i, filename in enumerate(files_to_process, start=1):
        print(f"Processing file {i}/{n_files}: {filename}")
        full_file_path = os.path.join(wavedocumentation_path, filename)

        try:
            period_value = extract_variable_from_filename(filename, "period")
            steepness_from_file = extract_variable_from_filename(filename, "steepness")
            # Insertion order here fixes the DataFrame column order.
            row_data = {
                "filename": filename,
                "date": extract_variable_from_filename(filename, "date"),
                "velocity": extract_variable_from_filename(filename, "velocity"),  # NaN: not encoded in the filename
                "period": period_value,
                "steepness": steepness_from_file,  # scaled value, e.g. 40.0
                "info": extract_variable_from_filename(filename, "info"),
                "test_number": extract_variable_from_filename(filename, "testnumber"),
            }
        except ValueError as e:
            print(f"  Skipping file {filename} due to error extracting variables: {e}")
            continue

        # Unparseable (string) period/steepness values are passed on as None.
        period_arg = period_value if isinstance(period_value, (int, float)) else None
        steepness_arg = steepness_from_file if isinstance(steepness_from_file, (int, float)) else None

        # Theoretical input amplitude. Assumes the filename steepness is
        # H/L * 1000 (40 -> H/L = 0.04); with H = 2a, a = (H/L) * wavelength / 2.
        input_amplitude_calculated = np.nan
        if period_arg is not None and period_arg > 0 and steepness_arg is not None:
            k_val = solve_dispersion_for_k(period_arg, water_depth, g)
            if k_val is not None and not np.isnan(k_val) and k_val > 0:
                wavelength = 2 * np.pi / k_val
                input_amplitude_calculated = (steepness_arg / 1000.0) * wavelength / 2.0
        row_data["inputAmplitude"] = input_amplitude_calculated

        for wp_name in wave_probe_positions_x:
            # target_amplitude_factor_for_start, num_peaks_skip_start,
            # num_peaks_skip_end and reflection_time_delay_sec use the
            # defaults from the function signature.
            row_data[f"{wp_name}_Avg_Amp"] = calculate_average_steady_state_amplitude(
                full_file_path,
                wp_name,
                wave_period_T=period_arg,
                steepness_val_from_filename=steepness_arg,
                steady_state_start_time=20.0,
                filter_order=4,
                cutoff_freq_hz=5.0,
                initial_skip_periods=15,
                end_buffer_periods=10,
            )

        all_data_rows.append(row_data)

    if not all_data_rows:
        print("No .bin files were successfully processed or found in the directory.")
        return pd.DataFrame()

    df = pd.DataFrame(all_data_rows)
    print("\nDataFrame created successfully.")
    return df

def test_plot_average_steady_state_amplitude(full_file_path, channel_name, 
                                           wave_period_T=None, 
                                           steepness_val_from_filename=None, # New
                                           steady_state_start_time=20.0, 
                                           filter_order=4, cutoff_freq_hz=5.0,
                                           initial_skip_periods=15, end_buffer_periods=10,
                                           target_amplitude_factor_for_start=0.3, 
                                           num_peaks_skip_start=10,             # New for plot mirroring
                                           num_peaks_skip_end=0,              # Changed from 10 to 0
                                           reflection_time_delay_sec_plot=10.0): # New parameter for plot delay
    """
    Visually check the amplitude-extraction pipeline for one wave probe channel.

    Loads the .bin file, low-pass filters and mean-detrends the channel, and draws
    a figure with three stacked subplots:

    1. Raw and filtered signal, with the expected arrival time of the end-wall
       reflection at WP8 marked for context.
    2. A WP8-based "ideal" window on the detrended signal of `channel_name`.
       The window runs from `steady_state_start_time + 10*T` to the WP8
       reflection arrival. If T is unknown it starts at `steady_state_start_time`;
       if the reflection time is unknown it ends at the end of the record.
       Peaks are detected and the first 5 are averaged.
    3. The analysis window for `channel_name`, chosen by logic that is meant to
       mirror `calculate_average_steady_state_amplitude`:
       - Start: skip `num_peaks_skip_start` peaks after the signal first reaches
         `target_amplitude_factor_for_start * input_amplitude`.
       - End: skip `num_peaks_skip_end` peaks back from the last peak before the
         (delayed) reflection time.
       - If either step fails, period-based fallbacks are used.
       All peaks inside the window are averaged.
       NOTE(review): this mirror is maintained by hand; verify it still matches
       `calculate_average_steady_state_amplitude`.

    Reflection arrival at a probe at x is x/Cg + 2*(L_tank - x)/Cg, measured from
    t=0 at the wavemaker. Cg is the linear-theory group velocity n*C, with
    n = 0.5*(1 + 2kh/sinh(2kh)). For Plot 3 only, `reflection_time_delay_sec_plot`
    is added to that arrival time.

    After showing the figure, `calculate_average_steady_state_amplitude` is called
    with the same settings and its result is printed for comparison.
    Returns None in all cases; load and validation problems are printed, not raised.

    Args:
        full_file_path (str): Full path to the .bin data file.
        channel_name (str): Wave probe channel to inspect (e.g. "WP3").
        wave_period_T (float, optional): Nominal wave period in seconds. Without
            it, no reflection times or input amplitude can be computed.
        steepness_val_from_filename (float, optional): Steepness code from the
            filename. It is divided by 1000 to give H/L (e.g. 40.0 -> 0.04) and
            used to compute the input amplitude (H/L * wavelength / 2).
        steady_state_start_time (float): Base start time (s) for the Plot 2
            window. Also the Plot 3 fallback start when T is unknown.
        filter_order (int): Order of the Butterworth low-pass filter.
        cutoff_freq_hz (float): Filter cutoff in Hz. It is clipped to 0.95*Nyquist
            if it is at or above Nyquist.
        initial_skip_periods (int): Plot 3 fallback start = initial_skip_periods*T.
        end_buffer_periods (int): Plot 3 fallback end = reflection time minus
            end_buffer_periods*T.
        target_amplitude_factor_for_start (float): Fraction of the input amplitude
            the detrended signal must reach to start the dynamic window.
        num_peaks_skip_start (int): Peaks skipped after the threshold crossing.
        num_peaks_skip_end (int): Peaks skipped back from the last peak before
            the reflection time.
        reflection_time_delay_sec_plot (float): Seconds added to the Plot 3
            reflection time. Also passed as `reflection_time_delay_sec` to the
            calculation function.
    """
    print(f"\n--- Test Plotting for {channel_name} from {os.path.basename(full_file_path)} ---")
    try:
        df = bin_to_dataframe(full_file_path)
    except Exception as e:
        print(f"Error loading data: {e}")
        return

    if channel_name not in df.columns or "Time" not in df.columns:
        print(f"Channel {channel_name} or Time not found in DataFrame. Available columns: {df.columns.tolist()}")
        return

    time_signal = df["Time"].values
    wave_signal = df[channel_name].values # This ensures data for the correct channel is used.

    # Calculate fs early for use in diagnostic print.
    # fs comes from the mean time step; this assumes roughly uniform sampling.
    fs = 0 
    if len(time_signal) > 1:
        fs = 1.0 / np.mean(np.diff(time_signal))
    else:
        print("Warning: Cannot calculate sampling frequency, time_signal too short.")
        # Potentially return or handle error, for now, fs will be 0 if not calculable

    # Diagnostic print:
    print(f"Successfully loaded data for {channel_name}. First 5 raw values: {wave_signal[:5]}")
    # Check if initial part of signal is mostly constant or has few unique values
    # Ensure fs is valid before using it for slicing
    if fs > 0 and len(wave_signal) > 5 : 
        slice_end_index = min(len(wave_signal), int(fs * 2)) # Ensure slice doesn't exceed array bounds
        if slice_end_index > 0 and len(np.unique(wave_signal[:slice_end_index])) < 10:
            print(f"Warning: Initial part of signal for {channel_name} (up to {slice_end_index} samples / ~2s) has very few unique values. This might indicate an issue, a very stable signal, or a delay before wave arrival.")
    elif len(wave_signal) > 5 and len(np.unique(wave_signal[:5])) < 3 : # Fallback check if fs is not valid
        print(f"Warning: First 5 samples for {channel_name} have very few unique values. This might indicate an issue or a very stable signal.")


    if len(time_signal) < 10:
        print("Insufficient data points.")
        return

    plt.figure(figsize=(15, 15)) # Adjusted figure size for clarity

    # --- Common Calculations ---
    # fs is already calculated above
    
    # Start from the raw signal; it is replaced only if filtering succeeds.
    filtered_signal_data = wave_signal 
    actual_cutoff_for_label = cutoff_freq_hz 

    if not (fs > 0 and np.isfinite(fs)):
        print(f"Invalid sampling frequency ({fs:.2f} Hz). Using unfiltered signal.")
    else: 
        nyquist_freq = 0.5 * fs
        processing_cutoff = cutoff_freq_hz 
        # A cutoff at or above Nyquist is invalid for a digital Butterworth filter,
        # so it is clipped to 95% of Nyquist.
        if processing_cutoff >= nyquist_freq:
            adjusted_cutoff = nyquist_freq * 0.95
            actual_cutoff_for_label = adjusted_cutoff 
            if adjusted_cutoff <= 0.01:
                print(f"Adjusted cutoff ({adjusted_cutoff:.2f} Hz) too low. Using unfiltered signal.")
                processing_cutoff = -1 
            else: 
                processing_cutoff = adjusted_cutoff
        
        if processing_cutoff > 0: 
            # Normalised cutoff (fraction of Nyquist) expected by scipy.signal.butter.
            Wn = processing_cutoff / nyquist_freq
            if not (0 < Wn < 1):
                print(f"Invalid Wn ({Wn:.2f}). Using unfiltered signal.")
            else:
                b, a = scipy.signal.butter(filter_order, Wn, btype='low', analog=False)
                if len(wave_signal) > 3 * filter_order: 
                    try:
                        # Forward-backward filtering, so peaks are not shifted in time.
                        filtered_signal_data = scipy.signal.filtfilt(b, a, wave_signal) 
                    except ValueError as ve:
                         print(f"filtfilt error: {ve}. Using unfiltered signal.")
                else:
                    print(f"Signal too short for filtfilt. Using unfiltered signal.")
        else: 
             print(f"Not filtering due to invalid cutoff ({processing_cutoff}).")

    # Remove the mean so peak heights are measured relative to the still-water offset.
    detrended_filtered_signal = filtered_signal_data - np.mean(filtered_signal_data)

    # Calculate reflection time for WP8 (general context).
    # Linear theory: C = omega/k and Cg = n*C. Arrival time = travel to WP8
    # plus the round trip from WP8 to the end wall (L_tank) and back.
    reflection_arrival_time_WP8_context = np.nan
    if wave_period_T and wave_period_T > 0 and "WP8" in wave_probe_positions_x:
        k_val_wp8 = solve_dispersion_for_k(wave_period_T, water_depth, g)
        if k_val_wp8 and not np.isnan(k_val_wp8) and k_val_wp8 > 0:
            omega_wp8 = 2 * np.pi / wave_period_T
            C_wp8 = omega_wp8 / k_val_wp8
            n_wp8 = 0.5 * (1 + (2 * k_val_wp8 * water_depth) / np.sinh(2 * k_val_wp8 * water_depth))
            Cg_wp8 = n_wp8 * C_wp8
            if Cg_wp8 > 1e-6:
                wp8_x_pos = wave_probe_positions_x["WP8"]
                time_to_reach_WP8 = wp8_x_pos / Cg_wp8
                time_WP8_to_wall_and_back = (2 * (L_tank - wp8_x_pos)) / Cg_wp8
                if time_WP8_to_wall_and_back >= 0: # Ensure WP8 is not considered beyond the tank end
                    reflection_arrival_time_WP8_context = time_to_reach_WP8 + time_WP8_to_wall_and_back
                    print(f"Context: WP8 reflection (from t=0 at wavemaker) expected around {reflection_arrival_time_WP8_context:.2f}s (Cg={Cg_wp8:.2f}m/s).")
                else:
                    print(f"Context: WP8 ({wp8_x_pos}m) is at or beyond tank end ({L_tank}m) for reflection path calculation.")
            else:
                print("Context: WP8 Cg_wp8 is too small for reflection calculation.")
        else:
            print("Context: WP8 k_val_wp8 not solvable for reflection calculation.")

    # Effective start time for "ideal" window (Plot 2): 10 extra wave periods after
    # steady_state_start_time when T is known.
    effective_plot_start_time = steady_state_start_time # Default
    if wave_period_T and wave_period_T > 0:
        effective_plot_start_time = steady_state_start_time + (10 * wave_period_T)
        print(f"Plot 2 'ideal' window starts at {effective_plot_start_time:.2f}s (steady_state_start_time + 10*T).")
    else:
        print(f"Plot 2 'ideal' window starts at {effective_plot_start_time:.2f}s (using steady_state_start_time).")

    
    # Peak-detection settings shared by all subplots: minimum peak height in metres,
    # and how many leading peaks Plot 2 averages.
    min_peak_height_param = 0.0005
    num_peaks_to_average_param = 5

    # --- Subplot 1: Raw and Filtered Signal ---
    plt.subplot(3, 1, 1)
    plt.plot(time_signal, wave_signal, label=f'Raw {channel_name}', color='blue', alpha=0.6)
    plt.plot(time_signal, filtered_signal_data, label=f'Filtered (Cutoff: {actual_cutoff_for_label:.2f}Hz)', color='red', linewidth=1.2)
    if not np.isnan(reflection_arrival_time_WP8_context):
        plt.axvline(reflection_arrival_time_WP8_context, color='darkcyan', linestyle=':', linewidth=1.5, label=f'WP8 Reflection Context ({reflection_arrival_time_WP8_context:.2f}s)')
    plt.title(f'1. Raw and Filtered Signal ({channel_name})')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (m)')
    plt.legend(fontsize='small')
    plt.grid(True, linestyle=':', alpha=0.5)

    # --- Subplot 2: WP8-Specific Ideal Sample Window ---
    # The window is bounded by WP8's reflection time, but the signal analysed is
    # that of channel_name.
    plt.subplot(3, 1, 2)
    plot2_start_time = effective_plot_start_time
    plot2_end_time = reflection_arrival_time_WP8_context if not np.isnan(reflection_arrival_time_WP8_context) else time_signal[-1]
    
    # Plot full detrended filtered signal as background
    plt.plot(time_signal, detrended_filtered_signal, label='Full Detrended Signal', color='lightgrey', alpha=0.7, zorder=1)

    # searchsorted assumes time_signal is sorted in ascending order.
    idx_start_plot2 = np.searchsorted(time_signal, plot2_start_time)
    idx_end_plot2 = np.searchsorted(time_signal, plot2_end_time)

    if idx_start_plot2 < idx_end_plot2 and idx_start_plot2 < len(detrended_filtered_signal):
        time_plot2 = time_signal[idx_start_plot2:idx_end_plot2]
        signal_plot2 = detrended_filtered_signal[idx_start_plot2:idx_end_plot2]
        if len(signal_plot2) >= 2:
            plt.plot(time_plot2, signal_plot2, label='Signal in WP8 Ideal Window', color='purple', linewidth=1.5, zorder=2)
            
            # Peaks must be at least 0.4*T apart (in samples) so only one crest per wave
            # is counted. Without T, a 0.1 s spacing is used.
            min_dist_p2 = int(fs * wave_period_T * 0.4) if (wave_period_T and wave_period_T > 0 and fs > 0) else int(fs * 0.1)
            if min_dist_p2 < 1: min_dist_p2 = 1
            peaks_idx_p2, _ = scipy.signal.find_peaks(signal_plot2, height=min_peak_height_param, distance=min_dist_p2)
            
            if len(peaks_idx_p2) > 0:
                peak_t_p2 = time_plot2[peaks_idx_p2]
                peak_a_p2 = signal_plot2[peaks_idx_p2]
                plt.plot(peak_t_p2, peak_a_p2, "x", color='orange', markersize=5, label='Peaks in Window')
                # Only the first num_peaks_to_average_param peaks of the window are averaged.
                amps_to_avg_p2 = peak_a_p2[:num_peaks_to_average_param]
                if len(amps_to_avg_p2) > 0:
                    avg_p2 = np.mean(amps_to_avg_p2)
                    print(f"Plot 2: Avg for WP8 ideal window ({plot2_start_time:.2f}s to {plot2_end_time:.2f}s): {avg_p2:.6f} m")
                    plt.plot(peak_t_p2[:len(amps_to_avg_p2)], amps_to_avg_p2, "o", mfc='red', mec='k', ms=7, label=f'Avg Peaks ({avg_p2:.4f}m)', zorder=3)
            else: print("Plot 2: No peaks in WP8 ideal window.")
        else: print("Plot 2: WP8 ideal window signal too short.")
    else: print("Plot 2: WP8 ideal window invalid or empty.")

    plt.axvline(plot2_start_time, color='gray', linestyle='--', label=f'Ideal Start ({plot2_start_time:.2f}s)')
    if not np.isnan(reflection_arrival_time_WP8_context):
        plt.axvline(reflection_arrival_time_WP8_context, color='darkcyan', linestyle=':', label=f'WP8 Reflect End ({reflection_arrival_time_WP8_context:.2f}s)')
    plt.title('2. WP8-Specific Ideal Sample Window & Analysis')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (m)')
    plt.legend(fontsize='small')
    plt.grid(True, linestyle=':', alpha=0.5)

    # --- Subplot 3: Visualization of `calculate_average_steady_state_amplitude` for `channel_name` ---
    plt.subplot(3, 1, 3)
    
    plt.plot(time_signal, detrended_filtered_signal, label='Full Detrended Signal', color='lightgrey', alpha=0.7, zorder=1)

    # --- Replicate dynamic window logic for Plot 3 visualization ---
    # -1 is a sentinel for "not determined yet"; the fallbacks below replace it.
    plot3_start_t = -1
    plot3_end_t = -1
    
    # Calculate Input Amplitude for dynamic windowing (for plotting):
    # a = (H/L) * wavelength / 2, with H/L = steepness code / 1000 and wavelength = 2*pi/k.
    calculated_input_amplitude_plot3 = np.nan
    if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0 and \
       steepness_val_from_filename is not None and isinstance(steepness_val_from_filename, (float, int)):
        k_val_ia_p3 = solve_dispersion_for_k(wave_period_T, water_depth, g)
        if k_val_ia_p3 is not None and not np.isnan(k_val_ia_p3) and k_val_ia_p3 > 0:
            wavelength_ia_p3 = 2 * np.pi / k_val_ia_p3
            actual_steepness_HL_ia_p3 = steepness_val_from_filename / 1000.0
            calculated_input_amplitude_plot3 = actual_steepness_HL_ia_p3 * wavelength_ia_p3 / 2.0
            plt.axhline(target_amplitude_factor_for_start * calculated_input_amplitude_plot3, color='cyan', linestyle=':', label=f'{int(target_amplitude_factor_for_start*100)}% Input Amp ({target_amplitude_factor_for_start * calculated_input_amplitude_plot3:.4f}m)')

    # Dynamic Start Time for Plot 3: find the first sample at or above the threshold,
    # then start at the (num_peaks_skip_start)-th peak after it (0-based), skipping
    # the ramp-up crests.
    first_index_above_target_plot3 = -1
    peaks_after_threshold_indices_abs_plot3 = []
    if not np.isnan(calculated_input_amplitude_plot3) and calculated_input_amplitude_plot3 > 0:
        target_signal_level_plot3 = target_amplitude_factor_for_start * calculated_input_amplitude_plot3
        indices_above_target_plot3 = np.where(detrended_filtered_signal >= target_signal_level_plot3)[0]
        
        if len(indices_above_target_plot3) > 0:
            first_index_above_target_plot3 = indices_above_target_plot3[0]
            plt.plot(time_signal[first_index_above_target_plot3], detrended_filtered_signal[first_index_above_target_plot3], 'cD', ms=8, label='Threshold Met')

            min_peak_dist_p3_start = int(fs * wave_period_T * 0.4) if (wave_period_T and wave_period_T > 0 and fs > 0) else int(fs * 0.1)
            if min_peak_dist_p3_start < 1: min_peak_dist_p3_start = 1
            
            peaks_rel_p3, _ = scipy.signal.find_peaks(
                detrended_filtered_signal[first_index_above_target_plot3:], 
                height=min_peak_height_param, distance=min_peak_dist_p3_start)
            # Peak indices are relative to the slice; shift them back to full-signal indices.
            peaks_after_threshold_indices_abs_plot3 = peaks_rel_p3 + first_index_above_target_plot3
            
            if len(peaks_after_threshold_indices_abs_plot3) > num_peaks_skip_start:
                start_peak_abs_idx_p3 = peaks_after_threshold_indices_abs_plot3[num_peaks_skip_start]
                plot3_start_t = time_signal[start_peak_abs_idx_p3]
                # Plot skipped peaks
                for k_skip in range(num_peaks_skip_start):
                    if k_skip < len(peaks_after_threshold_indices_abs_plot3):
                         idx = peaks_after_threshold_indices_abs_plot3[k_skip]
                         plt.plot(time_signal[idx], detrended_filtered_signal[idx], 'x', color='silver', ms=6, label='Skipped Start Peak' if k_skip==0 else None)
            else: # Not enough peaks to skip
                 peaks_after_threshold_indices_abs_plot3 = [] # Clear to indicate failure for this path
    
    if plot3_start_t == -1: # Fallback start time for plot3
        if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0:
            plot3_start_t = initial_skip_periods * wave_period_T
        else:
            plot3_start_t = steady_state_start_time

    # Reflection time for current channel (Plot 3): same linear-theory path as the WP8
    # context above, but for this probe's x-position, plus reflection_time_delay_sec_plot.
    reflection_time_current_chan_plot3 = np.nan
    if wave_period_T and wave_period_T > 0 and channel_name in wave_probe_positions_x:
        # ... (reflection calculation as before, assign to reflection_time_current_chan_plot3) ...
        k_val_curr = solve_dispersion_for_k(wave_period_T, water_depth, g)
        if k_val_curr and not np.isnan(k_val_curr) and k_val_curr > 0:
            omega_curr = 2 * np.pi / wave_period_T
            C_curr = omega_curr / k_val_curr
            n_curr = 0.5 * (1 + (2 * k_val_curr * water_depth) / np.sinh(2 * k_val_curr * water_depth))
            Cg_curr = n_curr * C_curr
            if Cg_curr > 1e-6:
                current_wp_x_pos = wave_probe_positions_x[channel_name]
                time_to_reach_curr_probe = current_wp_x_pos / Cg_curr
                time_curr_probe_to_wall_and_back = (2 * (L_tank - current_wp_x_pos)) / Cg_curr
                if time_curr_probe_to_wall_and_back >= 0:
                    reflection_time_current_chan_plot3 = time_to_reach_curr_probe + time_curr_probe_to_wall_and_back
                    if not np.isnan(reflection_time_current_chan_plot3): # Add delay if valid
                        reflection_time_current_chan_plot3 += reflection_time_delay_sec_plot

    # Dynamic End Time for Plot 3: among peaks before the (delayed) reflection time,
    # end at the peak num_peaks_skip_end positions back from the last one.
    peaks_before_reflection_indices_abs_plot3 = []
    if not np.isnan(reflection_time_current_chan_plot3):
        idx_refl_p3 = np.searchsorted(time_signal, reflection_time_current_chan_plot3)
        if idx_refl_p3 > 0:
            min_peak_dist_p3_end = int(fs * wave_period_T * 0.4) if (wave_period_T and wave_period_T > 0 and fs > 0) else int(fs * 0.1)
            if min_peak_dist_p3_end < 1: min_peak_dist_p3_end = 1
            
            peaks_refl_rel_p3, _ = scipy.signal.find_peaks(
                detrended_filtered_signal[:idx_refl_p3], 
                height=min_peak_height_param, distance=min_peak_dist_p3_end)
            peaks_before_reflection_indices_abs_plot3 = peaks_refl_rel_p3 # These are already absolute relative to detrended_filtered_signal
            
            if len(peaks_before_reflection_indices_abs_plot3) > num_peaks_skip_end:
                end_peak_abs_idx_p3 = peaks_before_reflection_indices_abs_plot3[-(num_peaks_skip_end + 1)]
                plot3_end_t = time_signal[end_peak_abs_idx_p3]
                # Plot skipped end peaks
                for k_skip in range(num_peaks_skip_end):
                    if k_skip < len(peaks_before_reflection_indices_abs_plot3):
                        idx = peaks_before_reflection_indices_abs_plot3[-(k_skip + 1)]
                        plt.plot(time_signal[idx], detrended_filtered_signal[idx], '+', color='gray', ms=7, label='Skipped End Peak' if k_skip==0 else None)
            else: # Not enough peaks to skip
                peaks_before_reflection_indices_abs_plot3 = [] # Clear to indicate failure

    if plot3_end_t == -1: # Fallback end time for plot3
        # Use the end of the record, or end_buffer_periods*T before the reflection
        # time if that is earlier.
        fallback_end_t_p3 = time_signal[-1]
        if not np.isnan(reflection_time_current_chan_plot3):
            if wave_period_T is not None and isinstance(wave_period_T, (float, int)) and wave_period_T > 0:
                potential_end_t = reflection_time_current_chan_plot3 - (end_buffer_periods * wave_period_T)
                fallback_end_t_p3 = min(fallback_end_t_p3, potential_end_t)
            else:
                fallback_end_t_p3 = min(fallback_end_t_p3, reflection_time_current_chan_plot3)
        plot3_end_t = fallback_end_t_p3
    
    if plot3_end_t <= plot3_start_t:
        print(f"Plot 3 ({channel_name}): Invalid window for visualization. End time {plot3_end_t:.2f}s <= Start time {plot3_start_t:.2f}s.")
        # Empty index range, so the window-drawing branch below is skipped.
        idx_start_plot3, idx_end_plot3 = 0, 0
    else:
        idx_start_plot3 = np.searchsorted(time_signal, plot3_start_t)
        # side='right' includes the end peak sample itself in the window.
        idx_end_plot3 = np.searchsorted(time_signal, plot3_end_t, side='right')

    # NOTE(review): this line is printed even when the window was found invalid above.
    print(f"Plot 3 ({channel_name}): Visualized analysis window from {plot3_start_t:.2f}s to {plot3_end_t:.2f}s.")

    if idx_start_plot3 < idx_end_plot3 and idx_start_plot3 < len(detrended_filtered_signal):
        time_plot3_viz = time_signal[idx_start_plot3:idx_end_plot3]
        signal_plot3_viz = detrended_filtered_signal[idx_start_plot3:idx_end_plot3]
        if len(signal_plot3_viz) >= 2:
            plt.plot(time_plot3_viz, signal_plot3_viz, label=f'Signal for {channel_name} Calc Window', color='green', linewidth=1.5, zorder=2)
            
            min_dist_p3_final = int(fs * wave_period_T * 0.4) if (wave_period_T and wave_period_T > 0 and fs > 0) else int(fs * 0.1)
            if min_dist_p3_final < 1: min_dist_p3_final = 1
            peaks_idx_p3_viz, _ = scipy.signal.find_peaks(signal_plot3_viz, height=min_peak_height_param, distance=min_dist_p3_final)

            if len(peaks_idx_p3_viz) > 0:
                peak_t_p3_viz = time_plot3_viz[peaks_idx_p3_viz]
                peak_a_p3_viz = signal_plot3_viz[peaks_idx_p3_viz]
                plt.plot(peak_t_p3_viz, peak_a_p3_viz, "o", mfc='darkorange', mec='k', ms=7, label=f'Peaks in Calc Window ({len(peak_a_p3_viz)} used)', zorder=3)
                # Unlike Plot 2, every peak in the window is averaged here.
                avg_p3_viz = np.mean(peak_a_p3_viz)
                print(f"Plot 3: Avg for {channel_name} visualized window ({plot3_start_t:.2f}s to {plot3_end_t:.2f}s): {avg_p3_viz:.6f} m")
            else: print(f"Plot 3: No peaks in {channel_name} visualized calc window.")
        else: print(f"Plot 3: {channel_name} visualized calc window signal too short.")
    else: print(f"Plot 3: {channel_name} visualized calc window invalid or empty.")

    plt.axvline(plot3_start_t, color='dimgray', linestyle='--', label=f'Calc Start ({plot3_start_t:.2f}s)')
    if not np.isnan(reflection_time_current_chan_plot3):
        # This time already includes reflection_time_delay_sec_plot.
        plt.axvline(reflection_time_current_chan_plot3, color='magenta', linestyle=':', label=f'{channel_name} Actual Reflect ({reflection_time_current_chan_plot3:.2f}s)')
    
    # Show calculated end time if it's meaningfully different from reflection or signal end
    is_end_meaningful = plot3_end_t < (time_signal[-1] - 0.1) # Not just the end of the signal
    if not np.isnan(reflection_time_current_chan_plot3):
        is_end_meaningful = is_end_meaningful and not math.isclose(plot3_end_t, reflection_time_current_chan_plot3)
    
    if is_end_meaningful :
        plt.axvline(plot3_end_t, color='darkgreen', linestyle='-.', label=f'Calc End ({plot3_end_t:.2f}s)')

    plt.title(f'3. Visualization of `calculate_average_steady_state_amplitude` for {channel_name}')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (m)')
    plt.legend(fontsize='small')
    plt.grid(True, linestyle=':', alpha=0.5)
    
    plt.tight_layout(pad=2.5, h_pad=3.0) # Adjust padding
    plt.show()

    # Call the main calculation function to compare its output (printed at the end of this test function).
    # It receives the file path and so presumably reloads the data itself.
    # NOTE(review): the positional order below must match that function's signature; confirm.
    original_avg_amp = calculate_average_steady_state_amplitude(
        full_file_path, channel_name, wave_period_T, 
        steepness_val_from_filename, # New
        steady_state_start_time, 
        filter_order, cutoff_freq_hz,
        initial_skip_periods, end_buffer_periods,
        target_amplitude_factor_for_start, # Pass new params
        num_peaks_skip_start,
        num_peaks_skip_end,
        reflection_time_delay_sec=reflection_time_delay_sec_plot # Pass delay to calc function
    )
    print(f"\n--- Output from `calculate_average_steady_state_amplitude` for {channel_name} (for comparison) ---")
    print(f"  Average amplitude: {original_avg_amp:.6f} m" if original_avg_amp is not None and not np.isnan(original_avg_amp) else "  Average amplitude: NaN or None")


def plot_amplitude_vs_position_for_tests(df, num_tests_to_plot=5):
    """
    Scatter-plot average amplitude against wave probe x-position for the first tests.

    The lowest `num_tests_to_plot` unique test numbers are selected, and each one
    gets its own colour from the 'viridis' colormap. For every probe in
    `wave_probe_positions_x`, the column "<probe>_Avg_Amp" is plotted at that
    probe's x-position. Non-numeric or NaN amplitudes are skipped. If several rows
    share a test number, their amplitudes are averaged per probe, ignoring NaN.

    The caller's DataFrame is not modified.

    Args:
        df (pd.DataFrame): The master DataFrame. It needs a 'test_number' column
            and "<probe>_Avg_Amp" amplitude columns.
        num_tests_to_plot (int): The number of lowest unique test numbers to plot
            (e.g. 5 for tests 1-5).
    """
    if df.empty:
        print("DataFrame is empty. Cannot generate scatter plot.")
        return
    if 'test_number' not in df.columns:
        print("Column 'test_number' not found in DataFrame. Cannot generate scatter plot.")
        return

    # Coerce test numbers on a new frame (assign) so the caller's df is left untouched.
    # Copy after dropna so the int cast below writes to an independent frame
    # rather than a possible view.
    df_filtered = df.assign(test_number=pd.to_numeric(df['test_number'], errors='coerce'))
    df_filtered = df_filtered.dropna(subset=['test_number']).copy()
    df_filtered['test_number'] = df_filtered['test_number'].astype(int)

    # Keep only the lowest num_tests_to_plot unique test numbers
    tests_to_plot = sorted(df_filtered['test_number'].unique())[:num_tests_to_plot]

    if not tests_to_plot:
        print(f"No tests found within the first {num_tests_to_plot} unique test numbers after filtering.")
        return

    df_subset = df_filtered[df_filtered['test_number'].isin(tests_to_plot)]

    if df_subset.empty:
        print(f"No data available for tests {tests_to_plot}. Cannot generate scatter plot.")
        return

    plt.figure(figsize=(12, 8))

    # One colour per test. pyplot.get_cmap(name, lut) is used because
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    colors = plt.get_cmap('viridis', len(tests_to_plot))

    for i, test_num in enumerate(tests_to_plot):
        test_data = df_subset[df_subset['test_number'] == test_num]
        if test_data.empty:
            continue

        x_positions = []
        y_amplitudes = []

        for wp_name, x_pos in wave_probe_positions_x.items():
            amp_col_name = f"{wp_name}_Avg_Amp"
            if amp_col_name not in test_data.columns:
                continue
            # to_numeric turns None or non-numeric entries into NaN, where np.isnan
            # would raise. mean() skips NaN and averages repeated runs of the same test.
            amplitude = pd.to_numeric(test_data[amp_col_name], errors='coerce').mean()
            if pd.notna(amplitude):
                x_positions.append(x_pos)
                y_amplitudes.append(amplitude)

        if x_positions:  # Only plot if there's data for this test
            plt.scatter(x_positions, y_amplitudes, color=colors(i), label=f'Test {test_num}', s=50, alpha=0.7)

    plt.xlabel('Wave Probe Position (m)')
    plt.ylabel('Average Amplitude (m)')
    plt.title(f'Amplitude vs. Wave Probe Position (First {len(tests_to_plot)} Tests)')
    plt.legend(title="Test Number")
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.show()


def plot_error_ratio_summary(df):
    """
    Scatter-plot the 'error_ratio' column against 'test_number'.

    Points are coloured by 'inputAmplitude', and a dashed reference line is drawn
    at a ratio of 1.0. The ratio is read from the DataFrame, not computed here. It
    is described as mean wave-probe amplitude / inputAmplitude and is produced
    elsewhere.

    All three columns are coerced to numeric on a copy. Rows missing a test number
    or an error ratio are dropped; rows with a missing inputAmplitude are kept. The
    caller's DataFrame is not modified.

    Args:
        df (pd.DataFrame): The master DataFrame. It must include the 'test_number',
                           'inputAmplitude', and 'error_ratio' columns.
    """
    if df.empty:
        print("DataFrame is empty. Cannot generate error ratio plot.")
        return

    required_cols = ['test_number', 'error_ratio', 'inputAmplitude']
    for col in required_cols:
        if col not in df.columns:
            print(f"Required column '{col}' not found. Cannot generate error ratio plot.")
            return

    df_plot = df.copy()
    # Coerce every plotted column, including the colour column, so that
    # object-dtype data (e.g. None or strings) becomes NaN instead of breaking scatter().
    for col in required_cols:
        df_plot[col] = pd.to_numeric(df_plot[col], errors='coerce')
    # Only the x and y values are required for a point to be drawn.
    df_plot.dropna(subset=['test_number', 'error_ratio'], inplace=True)

    if df_plot.empty:
        print("No valid data to plot for error ratio summary after NaN removal.")
        return

    df_plot.sort_values('test_number', inplace=True)

    plt.figure(figsize=(12, 8))
    scatter = plt.scatter(df_plot['test_number'], df_plot['error_ratio'],
                          c=df_plot['inputAmplitude'], cmap='viridis', s=60, alpha=0.8,
                          label='Error Ratio per Test')

    plt.axhline(1.0, color='green', linestyle='--', linewidth=1, label='Ideal Ratio (1.0)')

    plt.colorbar(scatter, label='Input Amplitude (m)')
    plt.xlabel('Test Number')
    plt.ylabel('Error Ratio (Mean WP Amplitude / Input Amplitude)')
    plt.title('Error Ratio of Measured vs. Input Amplitude (Color-coded by Input Amplitude)')
    plt.legend()
    plt.grid(True, linestyle=':', alpha=0.6)
    plt.show()


def plot_all_wps_raw_stacked(full_file_path, vertical_offset_scale=1.0):
    """
    Plots the raw signals for all wave probes (WP3-WP8) from a single file,
    stacked vertically on the same subplot for comparison of wave arrival times.

    Traces are drawn in probe-number order. The k-th trace (0-based) is shifted down
    by k * base_offset. The base offset is the peak-to-peak range of the first (up to)
    1000 samples of WP3, used when WP3 has more than 100 samples and that range exceeds
    1e-5. Otherwise it is 0.01. In both cases it is multiplied by vertical_offset_scale.

    Args:
        full_file_path (str): The full path to the .bin data file.
        vertical_offset_scale (float): Factor to scale the automatic vertical offset.
                                       Increase for more separation, decrease for less.

    Returns:
        None. Errors loading the file, a missing 'Time' column, or the absence of any
        wave-probe channel are reported with a printed message.
    """
    print(f"\n--- Plotting all WPs raw stacked from {os.path.basename(full_file_path)} ---")
    try:
        df = bin_to_dataframe(full_file_path)
    except Exception as e:
        print(f"Error loading data for stacked plot: {e}")
        return

    if "Time" not in df.columns:
        print("Time column not found in DataFrame for stacked plot.")
        return

    time_signal = df["Time"].values

    # Heuristic base offset from the early part of WP3. This is only a rough guide;
    # vertical_offset_scale lets the caller adjust the separation either way.
    base_offset = 0.01  # Fallback when WP3 is unavailable or nearly flat
    if "WP3" in df.columns and len(df["WP3"]) > 100:
        ptp = np.ptp(df["WP3"].values[:1000])
        if ptp > 1e-5:  # Only use it if there is some variation
            base_offset = ptp
    base_offset *= vertical_offset_scale

    # Plot in probe-number order ("WP3" -> 3, ...) and report any expected probe
    # that is absent from this file.
    by_probe_number = lambda name: int(name[2:])
    wp_channels_to_plot = sorted(
        (wp for wp in wave_probe_positions_x if wp in df.columns), key=by_probe_number)
    for channel_name in sorted(
            (wp for wp in wave_probe_positions_x if wp not in df.columns), key=by_probe_number):
        print(f"Channel {channel_name} not found in data for stacked plot.")

    if not wp_channels_to_plot:
        print("No wave probe channels found in data for stacked plot.")
        return

    plt.figure(figsize=(15, 10))

    has_samples = len(time_signal) > 0
    # Text labels sit 5% of the record length to the left of the first sample.
    label_x = time_signal[0] - (time_signal[-1] - time_signal[0]) * 0.05 if has_samples else None

    for i, channel_name in enumerate(wp_channels_to_plot):
        wave_signal = df[channel_name].values

        # Descending stack: 0, -offset, -2*offset, ... (explicit 0.0 avoids a "-0.00" label)
        offset_to_apply = -i * base_offset if i else 0.0

        plt.plot(time_signal, wave_signal + offset_to_apply,
                 label=f'{channel_name} (offset: {offset_to_apply:.2f}m)')
        if has_samples:
            plt.text(label_x,
                     wave_signal[0] + offset_to_apply,
                     channel_name,
                     verticalalignment='center')

    plt.title(f'Raw Wave Probe Signals (Stacked) - {os.path.basename(full_file_path)}')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (m) + Vertical Offset')
    plt.legend(fontsize='small', loc='upper right')
    plt.grid(True, linestyle=':', alpha=0.7)
    plt.tight_layout()
    plt.show()


# Example usage (optional, can be called from another script or main block)
if __name__ == '__main__':
    # Ensure the path is correct and files exist there for testing
    # wavedocumentation_path = r'C:\path\to\your\wave_documentation_folder' # Override if needed for direct testing

    # os.path.isdir is False for non-existent paths, so it also covers the existence check.
    if os.path.isdir(wavedocumentation_path):
        print(f"Attempting to create DataFrame from data in: {wavedocumentation_path}")
        master_df = create_dataframe()
        if not master_df.empty:
            print("\n--- DataFrame Head (Initial) ---")
            print(master_df.head(2))

            # Mean measured amplitude across whichever "<WP>_Avg_Amp" columns are present.
            wp_amp_cols = [f"{wp_name}_Avg_Amp" for wp_name in wave_probe_positions_x]
            existing_wp_amp_cols = [col for col in wp_amp_cols if col in master_df.columns]

            if existing_wp_amp_cols:
                # Ensure WP amplitude columns are numeric before calculating mean
                for col in existing_wp_amp_cols:
                    master_df[col] = pd.to_numeric(master_df[col], errors='coerce')
                master_df['mean_wp_amplitude'] = master_df[existing_wp_amp_cols].mean(axis=1, skipna=True)
            else:
                print("No wave probe amplitude columns found to calculate mean_wp_amplitude and error_ratio.")
                master_df['mean_wp_amplitude'] = np.nan

            # error_ratio = measured / requested amplitude. A missing 'inputAmplitude'
            # column must not raise a KeyError; it just leaves the ratio undefined.
            if 'inputAmplitude' in master_df.columns:
                input_amp = pd.to_numeric(master_df['inputAmplitude'], errors='coerce')
                master_df['inputAmplitude'] = input_amp
                # Zero amplitudes become NaN so the ratio is NaN rather than inf.
                master_df['error_ratio'] = master_df['mean_wp_amplitude'] / input_amp.where(input_amp != 0)
            else:
                print("Column 'inputAmplitude' not found; error_ratio cannot be calculated.")
                master_df['error_ratio'] = np.nan

            print("\n--- DataFrame Head (with error_ratio) ---")
            summary_cols = [col for col in ['filename', 'inputAmplitude', 'mean_wp_amplitude', 'error_ratio']
                            if col in master_df.columns]
            print(master_df[summary_cols].head(2))

            # Call the scatter plot function for amplitude vs position
            plot_amplitude_vs_position_for_tests(master_df.copy(), num_tests_to_plot=5)

            # Call the new error ratio plot function
            plot_error_ratio_summary(master_df.copy())

            # files_in_dir = [f for f in os.listdir(wavedocumentation_path) if f.endswith(".bin")]
            # if files_in_dir:
            #     # test_file_name = files_in_dir[0] # Take the first .bin file
            #     test_file_name = "1701-0_8-40(WaveDoc)#1.bin" # Specific file for testing reflection
            #     # test_file_name = "1701-1_0-40(WaveDoc)#1.bin" # Another file for variety
            #     test_full_path = os.path.join(wavedocumentation_path, test_file_name)

            #     if os.path.exists(test_full_path):
            #         try:
            #             test_period = extract_variable_from_filename(test_file_name, "period")
            #             if not isinstance(test_period, (float, int)): test_period = None 
            #             test_steepness_val = extract_variable_from_filename(test_file_name, "steepness")
            #             if not isinstance(test_steepness_val, (float, int)): test_steepness_val = None
            #         except ValueError:
            #             test_period = None 
            #             test_steepness_val = None

            #         # --- Test the individual channel plotting function for ALL WPs ---
            #         # try:
            #         #     temp_df_for_wp_check = bin_to_dataframe(test_full_path)
            #         #     available_wps_in_file = [wp for wp in wave_probe_positions_x.keys() if wp in temp_df_for_wp_check.columns]
            #         # except Exception as e:
            #         #     print(f"Could not load {test_file_name} to check available WPs: {e}")
            #         #     available_wps_in_file = []

            #         # if not available_wps_in_file:
            #         #     print(f"No wave probes found or could be loaded from {test_file_name} for individual plotting.")
            #         # else:
            #         #     print(f"\nFound WPs in {test_file_name}: {available_wps_in_file}. Generating individual plots...")

            #         # for wp_to_plot in available_wps_in_file:
            #         #     print(f"\n--- Initiating visual test for {wp_to_plot} in {test_file_name} (Period: {test_period}) ---")
            #         #     test_plot_average_steady_state_amplitude(
            #         #         test_full_path, 
            #         #         wp_to_plot, 
            #         #         wave_period_T=test_period,
            #         #         steepness_val_from_filename=test_steepness_val, 
            #         #         steady_state_start_time=20.0, 
            #         #         filter_order=4,
            #         #         cutoff_freq_hz=5.0,
            #         #         initial_skip_periods=15, 
            #         #         end_buffer_periods=10,
            #         #     )

            #         # --- Test the stacked raw plot function (called once for the file) ---
            #         # if available_wps_in_file: 
            #         #     print(f"\n--- Initiating stacked raw plot for all WPs in {test_file_name} ---")
            #         #     plot_all_wps_raw_stacked(test_full_path, vertical_offset_scale=1.5) 

            #     else:
            #         print(f"Test file {test_file_name} not found at {test_full_path}")
            # else:
            #     print("No .bin files found in the directory to run the visual test.")

        else:
            print("Resulting DataFrame is empty.")
    else:
        print(f"Error: The specified wavedocumentation_path does not exist or is not a directory: {wavedocumentation_path}")
        print("Please ensure the path is correct and contains .bin files.")


