In [1]:
import pandas as pd
import numpy as np
import os
import glob
from tqdm import tqdm

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# --- Configuration ---
# Based on observations of the provided CSV data (z_m values are typically small)
DETECTION_THRESHOLD_MULTIPLIER = 1.5  # Multiplier for STD to define peak detection threshold
EXIT_THRESHOLD_MULTIPLIER = 1.0       # Multiplier for STD to define wake end threshold
MIN_WAKE_DURATION_SECONDS = 0.5       # Minimum duration in seconds for a detected event to be considered a wake
MERGE_GAP_SECONDS = 1.5               # Maximum gap between two wake segments to merge them
BUFFER_SECONDS = 0.1                  # Small buffer to add at the start/end of detected segments

In [3]:
def load_and_preprocess_data(file_path):
    """
    Loads the CSV data and performs basic preprocessing.
    """
    try:
        df = pd.read_csv(file_path)
        # Ensure 't_s' is sorted, though typically it is in time series data
        df = df.sort_values(by='t_s').reset_index(drop=True)
        return df
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return None
    except Exception as e:
        print(f"Error loading or preprocessing data from {file_path}: {e}")
        return None

In [4]:
def get_ground_truth_wakes(df):
    """
    Extracts ground truth wake intervals from the 'wake_label' column.
    Returns a list of (start_time, end_time) tuples.
    """
    ground_truth_wakes = []
    in_wake = False
    wake_start_time = None

    for i, row in df.iterrows():
        t_s = row['t_s']
        wake_label = row['wake_label']

        if wake_label == 1 and not in_wake:
            wake_start_time = t_s
            in_wake = True
        elif wake_label == 0 and in_wake:
            wake_end_time = t_s
            ground_truth_wakes.append((wake_start_time, wake_end_time))
            in_wake = False
    
    # Handle case where wake extends to the end of the time series
    if in_wake:
        ground_truth_wakes.append((wake_start_time, df['t_s'].iloc[-1]))

    return ground_truth_wakes

In [5]:
def detect_wakes_rule_based(df):
    """
    Detects wakes using a simple rule-based method based on z_m.
    Returns a list of (start_time, end_time) tuples for detected wakes.
    """
    z_m_abs = df['z_m'].abs()
    
    # Calculate Mean Absolute Deviation and Standard Deviation of z_m_abs for dynamic thresholds
    mean_abs_z = z_m_abs.mean()
    std_abs_z = z_m_abs.std()

    if std_abs_z == 0: # Avoid division by zero if z_m is constant
        # print("Warning: Standard deviation of z_m is zero. Cannot apply dynamic thresholds.")
        return []

    # Define thresholds
    detection_threshold = mean_abs_z + DETECTION_THRESHOLD_MULTIPLIER * std_abs_z
    exit_threshold = mean_abs_z + EXIT_THRESHOLD_MULTIPLIER * std_abs_z

    detected_segments = []
    in_segment = False
    segment_start_idx = None
    
    # Calculate time step for buffering
    time_diffs = df['t_s'].diff().dropna().unique()
    if len(time_diffs) > 0:
        avg_time_step = np.mean(time_diffs)
    else:
        avg_time_step = 0.001 # Default small time step if no diffs (e.g., single point)

    buffer_points = int(BUFFER_SECONDS / avg_time_step) if avg_time_step > 0 else 0
    min_wake_points = int(MIN_WAKE_DURATION_SECONDS / avg_time_step) if avg_time_step > 0 else 0
    merge_gap_points = int(MERGE_GAP_SECONDS / avg_time_step) if avg_time_step > 0 else 0

    # Step 1: Identify points above the detection threshold
    is_above_detection_threshold = z_m_abs > detection_threshold

    # Step 2: Identify continuous segments using a state machine
    for i in range(len(df)):
        current_t_s = df['t_s'].iloc[i]
        
        if is_above_detection_threshold.iloc[i] and not in_segment:
            segment_start_idx = i
            in_segment = True
        elif not is_above_detection_threshold.iloc[i] and in_segment:
            # Simple check: if current value is below exit threshold, mark segment end
            if z_m_abs.iloc[i] < exit_threshold:
                segment_end_idx = i
                # Apply buffer
                actual_start_idx = max(0, segment_start_idx - buffer_points)
                actual_end_idx = min(len(df) - 1, segment_end_idx + buffer_points)
                
                detected_segments.append((actual_start_idx, actual_end_idx))
                in_segment = False
    
    # Handle case where wake extends to the end of the time series
    if in_segment:
        actual_start_idx = max(0, segment_start_idx - buffer_points)
        actual_end_idx = len(df) - 1 # Extends to the end
        detected_segments.append((actual_start_idx, actual_end_idx))

    # Convert detected segments from indices to time and merge close segments
    predicted_wakes_time = []
    if detected_segments:
        # Initial conversion to time values
        current_wake_start_time = df['t_s'].iloc[detected_segments[0][0]]
        current_wake_end_time = df['t_s'].iloc[detected_segments[0][1]]

        for i in range(1, len(detected_segments)):
            next_segment_start_time = df['t_s'].iloc[detected_segments[i][0]]
            next_segment_end_time = df['t_s'].iloc[detected_segments[i][1]]

            # Check if the gap between current and next segment is small enough to merge
            if (next_segment_start_time - current_wake_end_time) <= MERGE_GAP_SECONDS:
                current_wake_end_time = max(current_wake_end_time, next_segment_end_time)
            else:
                predicted_wakes_time.append((current_wake_start_time, current_wake_end_time))
                current_wake_start_time = next_segment_start_time
                current_wake_end_time = next_segment_end_time
        
        predicted_wakes_time.append((current_wake_start_time, current_wake_end_time))
    
    # Filter out very short wakes
    final_predicted_wakes = []
    for start, end in predicted_wakes_time:
        if (end - start) >= MIN_WAKE_DURATION_SECONDS:
            final_predicted_wakes.append((start, end))

    return final_predicted_wakes

In [6]:
def assign_predicted_labels_to_df(df, predicted_wakes_intervals):
    """
    Adds a 'predicted_wake_label' column to the DataFrame based on detected wake intervals.
    """
    # Initialize all predicted_wake_label to 0 (no wake)
    df['predicted_wake_label'] = 0

    # Iterate through detected intervals and set label to 1 for timestamps within them
    for start_time, end_time in predicted_wakes_intervals:
        # Find indices where t_s is within the detected wake interval
        # Using vectorized operation for efficiency
        mask = (df['t_s'] >= start_time) & (df['t_s'] < end_time)
        df.loc[mask, 'predicted_wake_label'] = 1
    return df

In [7]:
def calculate_iou(interval1, interval2):
    """
    Calculates the Intersection over Union (IoU) of two time intervals.
    Intervals are (start, end) tuples.
    """
    start1, end1 = interval1
    start2, end2 = interval2

    # Calculate intersection
    intersection_start = max(start1, start2)
    intersection_end = min(end1, end2)
    
    intersection_duration = max(0, intersection_end - intersection_start)

    # Calculate union
    union_duration = (max(end1, end2) - min(start1, start2))

    if union_duration == 0:
        return 0.0 # No union, no overlap
    
    return intersection_duration / union_duration

In [8]:
def evaluate_detection(ground_truth_wakes, predicted_wakes, iou_threshold=0.5):
    """
    Evaluates the detection performance based on IoU.
    Returns TP, FP, FN, Accuracy, Precision, Recall, F1-Score.
    """
    tp = 0
    fp = 0
    fn = 0

    # Keep track of ground truth wakes that have been matched
    matched_gt_indices = set()
    # Keep track of predicted wakes that have been used as TP
    used_pred_indices = set()

    # Iterate through predicted wakes to find matches
    for pred_idx, pred_wake in enumerate(predicted_wakes):
        # Find the best ground truth match for the current predicted wake
        best_iou_for_pred = 0.0
        potential_gt_idx_for_pred = -1

        for gt_idx, gt_wake in enumerate(ground_truth_wakes):
            if gt_idx in matched_gt_indices: # Ground truth already matched, skip
                continue

            iou = calculate_iou(gt_wake, pred_wake)
            if iou > best_iou_for_pred:
                best_iou_for_pred = iou
                potential_gt_idx_for_pred = gt_idx
        
        # If a predicted wake has a good enough overlap with an unmatched ground truth wake
        if best_iou_for_pred >= iou_threshold and potential_gt_idx_for_pred != -1:
            tp += 1
            matched_gt_indices.add(potential_gt_idx_for_pred)
            used_pred_indices.add(pred_idx) # Mark this predicted wake as used for a TP
        else:
            fp += 1 # This predicted wake didn't match any unmatched ground truth wake sufficiently

    # Calculate False Negatives: ground truth wakes that were not matched by any predicted wake
    fn = len(ground_truth_wakes) - len(matched_gt_indices)

    total_ground_truth = len(ground_truth_wakes)
    
    # Calculate metrics
    # Accuracy here refers to the proportion of correctly detected ground truth wakes
    accuracy = tp / total_ground_truth if total_ground_truth > 0 else 0.0
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    
    if (precision + recall) == 0:
        f1_score = 0.0
    else:
        f1_score = 2 * (precision * recall) / (precision + recall)
    
    return tp, fp, fn, accuracy, precision, recall, f1_score

In [9]:
if __name__ == "__main__":
    base_data_dir = 'processed_ts' 
    output_base_dir = 'detected_wakes_rule_based2' 

    dataset_splits = {
        'train': [], 
        'valid': [], 
        'test': []   
    }

    # Dynamically find files in your actual directory structure
    for split_name in dataset_splits.keys():
        split_path = os.path.join(base_data_dir, split_name)
        if os.path.exists(split_path):
            dataset_splits[split_name] = glob.glob(os.path.join(split_path, '*.csv'))
        else:
            print(f"Warning: Directory '{split_path}' not found. Skipping {split_name} split.")
            dataset_splits[split_name] = [] # Ensure it's an empty list if directory doesn't exist


    print("--- Starting Dataset Evaluation and Prediction Saving ---")

    overall_ground_truth_wakes = []
    overall_predicted_wakes = []

    for split_name, file_paths in dataset_splits.items():
        if not file_paths:
            print(f"\nNo files found for '{split_name}' split. Skipping evaluation for this split.")
            continue

        print(f"\n--- Processing {split_name.upper()} Set ({len(file_paths)} files) ---")
        
        split_ground_truth_wakes = []
        split_predicted_wakes = []

        # Wrap the file_paths iteration with tqdm for a progress bar
        for file_path in tqdm(file_paths, desc=f"Processing {split_name} files"):
            df = load_and_preprocess_data(file_path)
            if df is not None:
                gt_wakes = get_ground_truth_wakes(df)
                pred_wake_intervals = detect_wakes_rule_based(df)
                
                df_with_predictions = assign_predicted_labels_to_df(df.copy(), pred_wake_intervals) 
                
                split_ground_truth_wakes.extend(gt_wakes)
                split_predicted_wakes.extend(pred_wake_intervals) 
                
                relative_path = os.path.relpath(file_path, base_data_dir)
                output_file_path = os.path.join(output_base_dir, relative_path)
                
                os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
                
                df_with_predictions.to_csv(output_file_path, index=False)
            else:
                # Print errors for specific files outside of tqdm to not mess up progress bar
                print(f"\n  Skipping processing for file: {file_path} due to load error.")

        if split_ground_truth_wakes or split_predicted_wakes:
            tp, fp, fn, accuracy, precision, recall, f1_score = evaluate_detection(
                split_ground_truth_wakes, split_predicted_wakes
            )

            print(f"\n--- {split_name.upper()} Evaluation Results ---")
            print(f"  Total Ground Truth Wakes: {len(split_ground_truth_wakes)}")
            print(f"  Total Predicted Wakes: {len(split_predicted_wakes)}")
            print(f"  True Positives (TP): {tp}")
            print(f"  False Positives (FP): {fp}")
            print(f"  False Negatives (FN): {fn}")
            print(f"  Accuracy (TP / Total GT): {accuracy:.4f}")
            print(f"  Precision: {precision:.4f}")
            print(f"  Recall: {recall:.4f}")
            print(f"  F1-Score: {f1_score:.4f}")
            
            overall_ground_truth_wakes.extend(split_ground_truth_wakes)
            overall_predicted_wakes.extend(split_predicted_wakes)

        else:
            print(f"  No ground truth or predicted wakes found in {split_name} set.")

    # Overall evaluation
    if overall_ground_truth_wakes or overall_predicted_wakes:
        tp_overall, fp_overall, fn_overall, accuracy_overall, precision_overall, recall_overall, f1_score_overall = evaluate_detection(
            overall_ground_truth_wakes, overall_predicted_wakes
        )

        print("\n--- Overall Dataset Evaluation Results ---")
        print(f"Total Ground Truth Wakes: {len(overall_ground_truth_wakes)}")
        print(f"Total Predicted Wakes: {len(overall_predicted_wakes)}")
        print(f"True Positives (TP): {tp_overall}")
        print(f"False Positives (FP): {fp_overall}")
        print(f"False Negatives (FN): {fn_overall}")
        print(f"Accuracy (TP / Total GT): {accuracy_overall:.4f}")
        print(f"Precision: {precision_overall:.4f}")
        print(f"Recall: {recall_overall:.4f}")
        print(f"F1-Score: {f1_score_overall:.4f}")
    else:
        print("\nNo ground truth or predicted wakes found across the entire dataset for overall evaluation.")


--- Starting Dataset Evaluation and Prediction Saving ---

--- Processing TRAIN Set (9909 files) ---


Processing train files: 100%|██████████| 9909/9909 [27:30<00:00,  6.00it/s]



--- TRAIN Evaluation Results ---
  Total Ground Truth Wakes: 9672
  Total Predicted Wakes: 303561
  True Positives (TP): 1799
  False Positives (FP): 301762
  False Negatives (FN): 7873
  Accuracy (TP / Total GT): 0.1860
  Precision: 0.0059
  Recall: 0.1860
  F1-Score: 0.0115

--- Processing VALID Set (3411 files) ---


Processing valid files: 100%|██████████| 3411/3411 [09:29<00:00,  5.99it/s]



--- VALID Evaluation Results ---
  Total Ground Truth Wakes: 3332
  Total Predicted Wakes: 104902
  True Positives (TP): 600
  False Positives (FP): 104302
  False Negatives (FN): 2732
  Accuracy (TP / Total GT): 0.1801
  Precision: 0.0057
  Recall: 0.1801
  F1-Score: 0.0111

--- Processing TEST Set (6120 files) ---


Processing test files: 100%|██████████| 6120/6120 [17:17<00:00,  5.90it/s]



--- TEST Evaluation Results ---
  Total Ground Truth Wakes: 3343
  Total Predicted Wakes: 212650
  True Positives (TP): 626
  False Positives (FP): 212024
  False Negatives (FN): 2717
  Accuracy (TP / Total GT): 0.1873
  Precision: 0.0029
  Recall: 0.1873
  F1-Score: 0.0058


KeyboardInterrupt: 

In [None]:
# Overall evaluation
if overall_ground_truth_wakes or overall_predicted_wakes:
    tp_overall, fp_overall, fn_overall, accuracy_overall, precision_overall, recall_overall, f1_score_overall = evaluate_detection(
        overall_ground_truth_wakes, overall_predicted_wakes
    )

    print("\n--- Overall Dataset Evaluation Results ---")
    print(f"Total Ground Truth Wakes: {len(overall_ground_truth_wakes)}")
    print(f"Total Predicted Wakes: {len(overall_predicted_wakes)}")
    print(f"True Positives (TP): {tp_overall}")
    print(f"False Positives (FP): {fp_overall}")
    print(f"False Negatives (FN): {fn_overall}")
    print(f"Accuracy (TP / Total GT): {accuracy_overall:.4f}")
    print(f"Precision: {precision_overall:.4f}")
    print(f"Recall: {recall_overall:.4f}")
    print(f"F1-Score: {f1_score_overall:.4f}")
else:
    print("\nNo ground truth or predicted wakes found across the entire dataset for overall evaluation.")


In [18]:
# plotting the results
import matplotlib.pyplot as plt
import random

In [20]:
output_base_dir = 'detected_wakes_rule_based2' # Directory where processed CSVs are saved
plots_output_dir = 'wake_plots'          # Directory where plots will be saved

# --- Plotting Function ---
def plot_wake_detection(df, file_name, save_dir=None):
    """
    Plots the time series data with ground truth and predicted wake labels.

    Args:
        df (pd.DataFrame): DataFrame containing 't_s', 'z_m', 'wake_label',
                           and 'predicted_wake_label' columns.
        file_name (str): The name of the original CSV file for the plot title.
        save_dir (str, optional): Directory to save the plots. If None, plots are shown.
    """
    plt.figure(figsize=(15, 6))
    
    # Plot the entire time series in black
    plt.plot(df['t_s'], df['z_m'], color='black', linewidth=0.8, label='z_m (Vertical Displacement)')

    # Highlight Ground Truth Wakes (Red Background)
    in_gt_wake = False
    gt_wake_start = None
    gt_label_added = False # Flag to add label only once
    for i, row in df.iterrows():
        if row['wake_label'] == 1 and not in_gt_wake:
            gt_wake_start = row['t_s']
            in_gt_wake = True
        elif row['wake_label'] == 0 and in_gt_wake:
            plt.axvspan(gt_wake_start, row['t_s'], color='red', alpha=0.3, label='Ground Truth Wake' if not gt_label_added else "")
            gt_label_added = True
            in_gt_wake = False
    # Handle case where GT wake extends to the end of the series
    if in_gt_wake:
        plt.axvspan(gt_wake_start, df['t_s'].iloc[-1], color='red', alpha=0.3, label='Ground Truth Wake' if not gt_label_added else "")
        gt_label_added = True

    # Highlight Predicted Wakes (Green Background)
    in_pred_wake = False
    pred_wake_start = None
    pred_label_added = False # Flag to add label only once
    for i, row in df.iterrows():
        if row['predicted_wake_label'] == 1 and not in_pred_wake:
            pred_wake_start = row['t_s']
            in_pred_wake = True
        elif row['predicted_wake_label'] == 0 and in_pred_wake:
            plt.axvspan(pred_wake_start, row['t_s'], color='green', alpha=0.3, label='Predicted Wake' if not pred_label_added else "")
            pred_label_added = True
            in_pred_wake = False
    # Handle case where Predicted wake extends to the end of the series
    if in_pred_wake:
        plt.axvspan(pred_wake_start, df['t_s'].iloc[-1], color='green', alpha=0.3, label='Predicted Wake' if not pred_label_added else "")
        pred_label_added = True

    plt.title(f'Wake Detection for {file_name}', fontsize=16)
    plt.xlabel('Time (s)', fontsize=12)
    plt.ylabel('z_m (m)', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout() # Adjust plot to prevent labels from overlapping

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plot_path = os.path.join(save_dir, f'{os.path.basename(file_name).replace(".csv", "")}_wake_plot.png')
        plt.savefig(plot_path, dpi=300)
        plt.close() # Close the plot to free memory
    else:
        plt.show() # Display the plot

In [21]:
print("\n--- Generating Plots for 10 Random Validation Files ---")

valid_files_processed_path = os.path.join(output_base_dir, 'valid')
if not os.path.exists(valid_files_processed_path):
    print(f"Error: Validation output directory '{valid_files_processed_path}' not found. "
          "Please ensure your previous code has run and created this directory.")
else:
    validation_output_files = glob.glob(os.path.join(valid_files_processed_path, '*.csv'))
    
    if len(validation_output_files) == 0:
        print(f"No processed files found in '{valid_files_processed_path}'. Cannot generate plots.")
    else:
        # Select up to 10 random files, or all if less than 10
        num_plots_to_generate = min(10, len(validation_output_files))
        files_to_plot = random.sample(validation_output_files, num_plots_to_generate)
        
        os.makedirs(plots_output_dir, exist_ok=True) # Create directory for plots

        for file_path in tqdm(files_to_plot, desc="Generating plots"):
            try:
                df_to_plot = pd.read_csv(file_path)
                plot_file_name = os.path.basename(file_path)
                plot_wake_detection(df_to_plot, plot_file_name, save_dir=plots_output_dir)
            except Exception as e:
                print(f"Error plotting {file_path}: {e}")
        print(f"Plots saved to: {plots_output_dir}")


--- Generating Plots for 10 Random Validation Files ---


Generating plots: 100%|██████████| 10/10 [00:10<00:00,  1.05s/it]

Plots saved to: wake_plots



