In [3]:
from obspy import read
from obspy import UTCDateTime
import pandas as pd
from pathlib import Path
import numpy as np
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

# Load the pick metadata, restricted to the trace name, P/S pick times and
# source-sensor distance columns.
pick_data = pd.read_csv(
    filepath_or_buffer="dataset_earthquakes/metadata.csv",
    usecols=[
        "trace_name_original_1",
        "trace_p_pick_time",
        "trace_s_pick_time",
        "source_sensor_distance",
    ],
)

def sanitize_filename(name):
    """Normalise *name* so metadata trace names and on-disk file names compare equal.

    The value is converted to ``str``, every ``.``, ``:``, ``-`` and ``_`` is
    removed, surrounding whitespace is stripped, and the result is lower-cased.
    """
    # One translate pass deletes all separator characters at once.
    without_separators = str(name).translate(str.maketrans("", "", ".:-_"))
    return without_separators.strip().lower()

def calculate_snr(trace, time_point, pre_window=2, post_window=2):
    """
    Return the RMS signal-to-noise ratio of *trace* around *time_point*.

    The noise window is the ``pre_window`` seconds before the sample nearest to
    ``time_point``; the signal window is the ``post_window`` seconds starting at
    that sample. Both windows are clipped to the extent of the trace.

    Parameters
    ----------
    trace : obspy Trace (or any object with the same attributes)
        Needs ``data`` (1-D array), ``stats.sampling_rate`` and ``stats.starttime``.
    time_point : same type as ``trace.stats.starttime`` (e.g. UTCDateTime)
        Evaluation time; ``time_point - trace.stats.starttime`` must give seconds.
    pre_window, post_window : float
        Noise / signal window lengths in seconds.

    Returns
    -------
    float
        ``signal_rms / noise_rms``. The noise RMS is floored at 1e-10 (covers an
        empty or all-zero noise window); an empty signal window yields 0.0.
    """
    sampling_rate = trace.stats.sampling_rate
    data = trace.data
    n_samples = len(data)

    # Round to the nearest sample: int() truncation turns e.g. 28.999999999 into 28.
    point_index = int(round((time_point - trace.stats.starttime) * sampling_rate))
    # Clamp so times outside the trace never produce negative indices, which
    # Python slicing would silently interpret as counting from the end.
    point_index = min(max(point_index, 0), n_samples)

    noise_start = max(0, point_index - int(round(pre_window * sampling_rate)))
    signal_end = min(n_samples, point_index + int(round(post_window * sampling_rate)))

    # Cast each window to float64 before squaring: integer-typed sample arrays
    # would otherwise overflow silently. Only the window is copied, not the trace.
    noise_window = np.asarray(data[noise_start:point_index], dtype=np.float64)
    signal_window = np.asarray(data[point_index:signal_end], dtype=np.float64)

    noise_rms = np.sqrt(np.mean(noise_window ** 2)) if noise_window.size else 0.0
    signal_rms = np.sqrt(np.mean(signal_window ** 2)) if signal_window.size else 0.0

    # Floor the noise level so flat or empty noise windows never divide by zero.
    return float(signal_rms / max(noise_rms, 1e-10))

def calculate_snr_series(trace, pre_window=2, post_window=2):
    """
    Evaluate ``calculate_snr`` along *trace* roughly every 0.1 s.

    Only positions where the complete ``pre_window`` noise window and the
    complete ``post_window`` signal window fit inside the trace are evaluated,
    so every value is computed from full-length windows.

    Returns
    -------
    times : numpy.ndarray
        Evaluation times, ``trace.stats.starttime + offset_seconds``.
    snr_values : numpy.ndarray
        SNR at each time. Both arrays are empty when the trace is shorter than
        ``pre_window + post_window``.
    """
    sampling_rate = trace.stats.sampling_rate
    pre_samples = int(round(pre_window * sampling_rate))
    post_samples = int(round(post_window * sampling_rate))

    # Evaluate SNR at fewer points to improve performance (every ~0.1 seconds).
    step = max(1, int(sampling_rate / 10))

    # Start at pre_samples: earlier positions have a truncated or empty noise
    # window, which calculate_snr replaces by a tiny noise floor and would
    # report as spurious, enormous SNR values.
    first_index = pre_samples
    # Inclusive: the last signal window may end exactly at the end of the trace.
    last_index = len(trace.data) - post_samples

    times = []
    snr_values = []
    for i in range(first_index, last_index + 1, step):
        current_time = trace.stats.starttime + i / sampling_rate
        times.append(current_time)
        snr_values.append(calculate_snr(trace, current_time, pre_window, post_window))

    return np.array(times), np.array(snr_values)

def find_snr_peaks(times, snr_values, min_snr=1.7, min_distance_samples=10):
    """Return the times and values of local SNR maxima reaching ``min_snr``.

    ``min_distance_samples`` is the minimum spacing between accepted peaks,
    counted in elements of ``snr_values`` (i.e. SNR evaluation points, not raw
    trace samples). Two empty arrays are returned when there is no input or
    no qualifying peak.
    """
    if not len(snr_values):
        return np.array([]), np.array([])

    # Only the peak indices are needed; the properties dict is discarded.
    peak_idx = find_peaks(snr_values, height=min_snr, distance=min_distance_samples)[0]
    if peak_idx.size == 0:
        return np.array([]), np.array([])

    return times[peak_idx], snr_values[peak_idx]

def validate_p_picks(trace, p_pick_time, time_tolerance=1.0, snr_threshold=1.7):
    """
    Check a P-pick against the SNR peaks detected on *trace*.

    The pick is accepted when the SNR peak closest to ``p_pick_time`` lies
    within ``time_tolerance`` seconds of it. Peaks come from ``find_snr_peaks``
    with ``snr_threshold`` as the minimum height.

    Returns an 8-tuple ``(is_valid, p_pick_snr, nearest_peak_time,
    nearest_peak_snr, times, snr_values, peak_times, peak_snrs)``. When no peak
    is found, ``is_valid`` is False and both nearest-peak entries are None. The
    last four items are the SNR series and its peaks, e.g. for plotting.
    """
    times, snr_values = calculate_snr_series(trace)
    peak_times, peak_snrs = find_snr_peaks(times, snr_values, snr_threshold)
    p_pick_snr = calculate_snr(trace, p_pick_time)

    series_info = (times, snr_values, peak_times, peak_snrs)

    # Guard clause: without any SNR peak the pick cannot be confirmed.
    if len(peak_times) == 0:
        return (False, p_pick_snr, None, None) + series_info

    pick_ts = p_pick_time.timestamp
    offsets = np.abs([peak.timestamp - pick_ts for peak in peak_times])
    best = np.argmin(offsets)
    is_valid = offsets[best] <= time_tolerance
    return (is_valid, p_pick_snr, peak_times[best], peak_snrs[best]) + series_info

def plot_validation_results(trace, p_pick_time, times, snr_values, peak_times, peak_snrs,
                            snr_threshold=1.7):
    """Plot the trace and its SNR series with the P-pick, SNR peaks and threshold.

    Both panels use seconds since ``trace.stats.starttime`` on the x-axis.

    Parameters
    ----------
    trace : obspy Trace
        Waveform to draw (uses ``data``, ``stats.sampling_rate``, ``stats.starttime``).
    p_pick_time : UTCDateTime
        Catalogue P-pick, drawn as a dashed red line on both panels.
    times, snr_values : array-like
        SNR evaluation times and values (as returned by ``calculate_snr_series``).
    peak_times, peak_snrs : array-like
        Detected SNR peaks (as returned by ``find_snr_peaks``).
    snr_threshold : float, optional
        Value drawn as the "SNR Threshold" line. The default 1.7 matches the
        default threshold of ``validate_p_picks``; pass whatever threshold was
        used for peak detection so the plot reflects the actual analysis.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure; it is not saved or closed here.
    """
    start = trace.stats.starttime
    pick_offset = p_pick_time - start
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # Plot trace data
    trace_times = np.arange(len(trace.data)) / trace.stats.sampling_rate
    ax1.plot(trace_times, trace.data, 'b-', label='Trace Data')
    ax1.axvline(pick_offset, color='r', linestyle='--', label='P-Pick')
    ax1.set_title('Seismic Trace')
    ax1.legend()

    # Plot SNR values (only when the series is non-empty)
    if len(times) > 0:
        ax2.plot(times - start, snr_values, 'g-', label='SNR')
        if len(peak_times) > 0:
            ax2.scatter(peak_times - start, peak_snrs, color='r', label='SNR Peaks')
    ax2.axvline(pick_offset, color='r', linestyle='--', label='P-Pick')
    # Draw the threshold passed in, so the line agrees with peak detection.
    ax2.axhline(snr_threshold, color='k', linestyle=':', label='SNR Threshold')
    ax2.set_title('Signal-to-Noise Ratio')
    ax2.set_xlabel('Seconds since trace start')
    ax2.legend()

    plt.tight_layout()
    return fig

# Main processing loop: match each miniSEED file to its metadata row, validate
# the catalogue P-pick on every trace, save a diagnostic plot and collect results.
mseed_dir = "miniSEED_files"
results = []
total_files = 0
matching_files = 0

# Sanitize the metadata names once, instead of once per file inside the loop.
sanitized_names = pick_data['trace_name_original_1'].apply(sanitize_filename)

# Print all unique (sanitized) trace names in pick_data for debugging
print("Available traces in metadata:")
print(sanitized_names.unique())

# Match the extension case-insensitively (".MSEED" / ".mseed"); sort for a stable order.
for file_path in sorted(Path(mseed_dir).glob("*.[mM][sS][eE][eE][dD]")):
    total_files += 1
    mseed_name = sanitize_filename(file_path.name)
    print(f"\nProcessing file: {file_path.name}")
    print(f"Sanitized name: {mseed_name}")

    # Find matching metadata
    matched_rows = pick_data[sanitized_names == mseed_name]
    if len(matched_rows) == 0:
        print(f"No match found for file: {file_path.name}")
        continue
    if len(matched_rows) > 1:
        print(f"Warning: {len(matched_rows)} metadata rows match {file_path.name}; using the first")

    matching_files += 1
    matched_row = matched_rows.iloc[0]

    if pd.isna(matched_row['trace_p_pick_time']):
        print(f"No P-pick in metadata for file: {file_path.name}")
        continue

    try:
        p_pick = UTCDateTime(matched_row['trace_p_pick_time'])
        stream = read(file_path)
    except Exception as e:
        print(f"Error processing {file_path.name}: {e}")
        continue

    for trace in stream:
        print(f"Processing trace: {trace.id}")

        # Handle failures per trace so one bad trace does not drop the rest of the file.
        try:
            (is_valid, p_pick_snr, nearest_peak_time, nearest_peak_snr,
             times, snr_values, peak_times, peak_snrs) = validate_p_picks(trace, p_pick)
        except Exception as e:
            print(f"Error validating trace {trace.id} in {file_path.name}: {e}")
            continue

        try:
            fig = plot_validation_results(trace, p_pick, times, snr_values, peak_times, peak_snrs)
            # trace.id (network.station.location.channel) repeats across event files,
            # so prefix the file stem to keep one plot per file and trace.
            fig.savefig(f"validation_{file_path.stem}_{trace.id}.png")
        except Exception as e:
            print(f"Error plotting trace {trace.id} in {file_path.name}: {e}")
        finally:
            # Release every open figure, including one left behind by a failed plot.
            plt.close('all')

        # Store results
        results.append({
            'file_name': file_path.name,
            'trace_id': trace.id,
            'is_valid': is_valid,
            'p_pick_snr': p_pick_snr,
            'nearest_peak_time': nearest_peak_time,
            'nearest_peak_snr': nearest_peak_snr,
            'time_difference': None if nearest_peak_time is None else abs(nearest_peak_time - p_pick)
        })

# Create summary DataFrame
results_df = pd.DataFrame(results)
print("\nProcessing Summary:")
print(f"Total files found: {total_files}")
print(f"Files matching metadata: {matching_files}")
print(f"Total traces processed: {len(results_df)}")
# An empty results list yields a DataFrame without columns, so guard column access.
if len(results_df) > 0:
    print(f"Valid P-picks: {results_df['is_valid'].sum()}")
    print(f"Average P-pick SNR: {results_df['p_pick_snr'].mean():.2f}")
else:
    print("Valid P-picks: 0")

Available traces in metadata:
['3469238120230415t131755541619zwspozas2dn1mseed'
 '3469238120230415t131755542715zwspozas3dn1mseed'
 '3469238120230415t131755540100zwspozas4dn1mseed'
 '3469238120230415t131755539436zwspozas5dn1mseed'
 '3469238120230415t131755539433zwspozas6dn1mseed'
 '3416151120230221t001234318032zwspozas2dn1mseed'
 '3416222120230221t010656700647zwspozas2dn1mseed'
 '3416250120230221t020238098562zwspozas2dn1mseed'
 '3416157120230221t001439010032zwspozas2dn1mseed'
 '3416683120230221t124834239517zwspozas2dn1mseed'
 '3416801120230221t154846241375zwspozas2dn1mseed'
 '3416989120230221t202555853184zwspozas2dn1mseed'
 '3416440120230221t071334331365zwspozas2dn1mseed'
 '3416783120230221t150801281277zwspozas2dn1mseed'
 '3416341120230221t044909161225zwspozas2dn1mseed'
 '3416305120230221t034155052397zwspozas2dn1mseed'
 '3416356120230221t053359899097zwspozas2dn1mseed'
 '3416202120230221t005656518032zwspozas2dn1mseed'
 '3416579120230221t103627752007zwspozas2dn1mseed'
 '3416162120230221t0