In [None]:
# -*- coding: utf-8 -*-
"""
Title: Data Preprocessing for Earthquake Magnitude Estimation

This notebook implements data preprocessing for the STEAD (STanford EArthquake Dataset),
focusing on preparing data for single-station earthquake magnitude estimation.

Key Steps:
1. Data loading and initial filtering
2. Signal quality control (SNR thresholding)
3. Extraction of maximum amplitude features
4. Statistical analysis and visualizations

Dependencies:
- torch, h5py, pandas, numpy
- matplotlib, seaborn
- tqdm
"""

# Part 1: Setup and Imports

In [None]:
#------------------------------------------------------------------------------
# Part 1: Setup and Imports
#------------------------------------------------------------------------------

import time
import json
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import h5py
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# Wall-clock reference; the last cell subtracts it to report total runtime.
start_time = time.time()

# Select the compute device: CUDA when torch reports a usable GPU, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Keep cuDNN enabled but leave its benchmark (autotuning) mode switched off.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False

# Part 2: Data Loading and Helper Functions

In [None]:
#------------------------------------------------------------------------------
# Part 2: Data Loading and Helper Functions
#------------------------------------------------------------------------------

class EarthquakeDataset(Dataset):
    """Map-style dataset pairing each waveform sample with its label.

    Args:
        data: Sized, indexable collection of samples (one entry per event).
        labels: Sized, indexable collection of targets (one entry per event).

    Raises:
        ValueError: If ``data`` and ``labels`` do not have the same length.
    """
    def __init__(self, data, labels):
        # A length mismatch would otherwise surface much later as an
        # IndexError, or as silently ignored samples, during iteration.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length "
                f"(got {len(data)} and {len(labels)})"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of (sample, label) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]

def load_data(file_name, file_list, pre_samples=100, post_samples=2900):
    """
    Load fixed-length waveform windows and magnitudes from an HDF5 file.

    For each event ID the trace stored under ``data/<event ID>`` is cut to
    the rows ``[p_arrival - pre_samples, p_arrival + post_samples)``, where
    ``p_arrival`` is the dataset's ``p_arrival_sample`` attribute. The label
    is the ``source_magnitude`` attribute rounded to two decimals.

    Events that are missing from the file, or whose window does not fit
    inside the trace, are reported and skipped so that every returned window
    has the same length and the windows can be stacked.

    Args:
        file_name: Path to HDF5 file
        file_list: Iterable of event IDs to load
        pre_samples: Number of rows kept before the P arrival (default 100)
        post_samples: Number of rows kept from the P arrival on (default 2900)

    Returns:
        Tuple of (data tensor, labels tensor), both float32
    """
    X = []
    Y = []
    # The context manager closes the file even when an event raises mid-loop.
    with h5py.File(file_name, 'r') as dtfl:
        for evi in tqdm(file_list):
            dataset = dtfl.get(f'data/{evi}')
            if dataset is None:
                print(f"Dataset not found for event ID: {evi}")
                continue
            data = np.array(dataset)
            spt = int(dataset.attrs['p_arrival_sample'])
            start = spt - pre_samples
            stop = spt + post_samples
            # A negative start would wrap around to the end of the trace and
            # a stop beyond the trace would be truncated; both produce a
            # window of the wrong length, so such events are skipped.
            if start < 0 or stop > data.shape[0]:
                print(f"Window [{start}, {stop}) out of range for event ID: {evi}")
                continue
            X.append(data[start:stop, :])
            Y.append(round(float(dataset.attrs['source_magnitude']), 2))
    return torch.tensor(np.array(X), dtype=torch.float32), torch.tensor(np.array(Y), dtype=torch.float32)

def string_convertor(dd):
    """Convert string-format SNR values to float list.

    The input is split on whitespace and any square brackets attached to a
    token are stripped before conversion, so ``"[56.8 55.4 47.4]"``,
    ``"[ 56.8 55.4 47.4 ]"`` and ``"[12.3]"`` are all handled. Tokens made
    only of brackets contribute nothing to the result.

    Args:
        dd: String holding whitespace-separated numbers, optionally wrapped
            in square brackets.

    Returns:
        List with one entry per numeric token: the parsed float, or ``None``
        when the token cannot be parsed as a number.
    """
    SNR = []
    for token in dd.split():
        digits = token.strip('[]')
        if not digits:
            # Bracket-only token such as "[" or "]".
            continue
        try:
            SNR.append(float(digits))
        except ValueError:
            SNR.append(None)
    return SNR

# Part 3: Data Loading and Initial Processing

In [None]:
#------------------------------------------------------------------------------
# Part 3: Data Loading and Initial Processing
#------------------------------------------------------------------------------

# Define file paths
file_name = "merge.hdf5"  # Replace with your path
csv_file = "merge.csv"    # Replace with your path

# Verify file existence. Explicit raises are used instead of `assert`
# because assert statements are removed when Python runs with -O.
if not os.path.isfile(file_name):
    raise FileNotFoundError(f"HDF5 file not found at {file_name}")
if not os.path.isfile(csv_file):
    raise FileNotFoundError(f"CSV file not found at {csv_file}")

# Load initial dataset (the metadata table; one row per record)
df = pd.read_csv(csv_file, low_memory=False)
print(f"Initial number of records: {len(df)}")

# Part 4: Initial Data Analysis and Visualization

In [None]:
#------------------------------------------------------------------------------
# Part 4: Initial Data Analysis and Visualization
#------------------------------------------------------------------------------

# Tally the records per trace category before any filtering is applied
trace_category_counts_before = df["trace_category"].value_counts()
print("\nValue counts before filtering:")
print(trace_category_counts_before)

# Magnitude extremes shown in the annotation box of the histogram
magnitudes = df["source_magnitude"]
max_magnitude = magnitudes.max()
min_magnitude = magnitudes.min()

# Histogram (with KDE overlay) of the unfiltered magnitude distribution
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(magnitudes, bins=30, kde=True, ax=ax)
ax.set_xlabel('Magnitude', fontweight='bold', fontsize=14)
ax.set_ylabel('Number', fontweight='bold', fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.grid(True)
ax.text(
    6.5,
    120000,
    f'Max: {max_magnitude:.2f} M\nMin: {min_magnitude:.2f} M',
    bbox=dict(facecolor='none', edgecolor='red', boxstyle='round,pad=1'),
    fontsize=14,
)
fig.tight_layout()
plt.show()

# Part 5: Data Filtering and Processing

In [None]:
#------------------------------------------------------------------------------
# Part 5: Data Filtering and Processing
#------------------------------------------------------------------------------

# Convert to datetime and sort chronologically
df['source_origin_time'] = pd.to_datetime(df['source_origin_time'])
df = df.sort_values(by='source_origin_time')

# Apply basic filters
print("\nApplying basic filters...")
df = df[df.trace_category == 'earthquake_local']
df = df[df.source_distance_km <= 110]
df = df[df.source_magnitude_type == 'ml']

# Apply sample-based filters
df = df[df.p_arrival_sample >= 200]
df = df[df.p_arrival_sample + 2900 <= 6000]
df = df[df.p_arrival_sample <= 1500]
df = df[df.s_arrival_sample >= 200]
# .copy(): the column assignments below must write to an independent frame
# rather than to a slice of the previous one (avoids SettingWithCopyWarning).
df = df[df.s_arrival_sample <= 2500].copy()


def _coda_end_to_float(value):
    """Parse one coda-end entry (e.g. a bracket-wrapped string) into a float.

    Missing, already-numeric, and unparseable entries are tolerated:
    anything that cannot be converted becomes NaN so the dropna() below
    removes it. This also keeps the cell re-runnable on converted data.
    """
    if isinstance(value, str):
        value = value.strip('[]')
    try:
        return float(value)
    except (TypeError, ValueError):
        return np.nan


def _mean_snr(value):
    """Average the SNR values listed in one entry; NaN if any is unusable.

    A NaN result fails the ``>= 20`` threshold below, so rows with missing
    or unparseable SNR values are dropped instead of raising inside np.mean.
    """
    if not isinstance(value, str):
        # Already numeric (cell re-run) or missing.
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan
    parsed = string_convertor(value)
    if not parsed or any(v is None for v in parsed):
        return np.nan
    return float(np.mean(parsed))


# Process coda end samples
print("\nProcessing coda end samples...")
df['coda_end_sample'] = df['coda_end_sample'].apply(_coda_end_to_float)
df = df.dropna(subset=['coda_end_sample'])
df = df[df['coda_end_sample'] <= 3000]

# Additional parameter filters
print("\nApplying parameter filters...")
df = df[df.p_travel_sec.notnull()]
df = df[df.p_travel_sec > 0]
df = df[df.source_distance_km.notnull()]
df = df[df.source_distance_km > 0]
df = df[df.source_depth_km.notnull()]
df = df[df.source_magnitude.notnull()]
df = df[df.back_azimuth_deg.notnull()]
# .copy() for the same reason as above: snr_db is reassigned next.
df = df[df.back_azimuth_deg > 0].copy()

# Process SNR: collapse each entry to its mean and keep rows with mean >= 20
print("\nProcessing SNR values...")
df['snr_db'] = df['snr_db'].apply(_mean_snr)
df = df[df.snr_db >= 20]

# Plot filtered magnitude distribution
plt.figure(figsize=(10, 6))
sns.histplot(df["source_magnitude"], bins=30, kde=True, color='brown')
plt.xlabel('Magnitude', fontweight='bold', fontsize=14)
plt.ylabel('Number', fontweight='bold', fontsize=14)
plt.tick_params(axis='both', which='major', labelsize=14)
plt.grid(True)
max_magnitude = df["source_magnitude"].max()
min_magnitude = df["source_magnitude"].min()
plt.text(4, 35000, f'Max: {max_magnitude:.2f} M\nMin: {min_magnitude:.2f} M',
         bbox=dict(facecolor='none', edgecolor='red', boxstyle='round,pad=1'),
         fontsize=14)
plt.tight_layout()
plt.show()

print(f"\nNumber of records after filtering: {len(df)}")

# Part 6: Multi-Observation Station Processing

In [None]:
#------------------------------------------------------------------------------
# Part 6: Multi-Observation Station Processing
#------------------------------------------------------------------------------

# Identify stations with multiple observations
print("\nProcessing multi-observation stations...")
uniq_ins = df.receiver_code.unique()

# Minimum number of records a station needs to count as "multi-observation"
MIN_STATION_OBSERVATIONS = 400

# Count every station in a single pass instead of re-scanning the whole
# column once per station. Missing codes are not counted and report 0.
station_counts = df.receiver_code.value_counts()

labM = []
for station in uniq_ins:
    station_count = int(station_counts.get(station, 0))
    print(f"Station {str(station)}: {station_count} observations")
    if station_count >= MIN_STATION_OBSERVATIONS:  # Threshold for multi-observations
        labM.append(str(station))

print(f"\nNumber of multi-observation stations: {len(labM)}")

# Save the multi-observations list
multi_observations_path = "multi_observations.npy"
np.save(multi_observations_path, labM)

# Load the multi-observations file for verification
multi_observations = np.load(multi_observations_path)

# Filter dataset to include only multi-observations. The vectorised
# membership test keeps the row order of df.
is_multi_station = df['receiver_code'].isin(multi_observations.tolist())
ev_list = df.loc[is_multi_station, 'trace_name'].tolist()

# Verify events exist in HDF5 file
with h5py.File(file_name, 'r') as dtfl:
    available_event_ids = {key.split('/')[-1] for key in dtfl['data'].keys()}

# Keep only valid events
valid_events = [event for event in ev_list if event in available_event_ids]
print(f"\nNumber of valid events: {len(valid_events)}")

# Part 7: Load Waveform Data and Extract Features

In [None]:
#------------------------------------------------------------------------------
# Part 7: Load Waveform Data and Extract Features
#------------------------------------------------------------------------------

# Read the waveform windows and magnitude labels of every validated event
print("\nLoading waveform data...")
data, labels = load_data(file_name=file_name, file_list=valid_events)
data_shape, labels_shape = data.shape, labels.shape
print(f"Data shape: {data_shape}, Labels shape: {labels_shape}")

# Function to extract max values
def extract_max_values(data):
    """Return the largest absolute value along the first dimension.

    Args:
        data: Tensor to reduce; dimension 0 is collapsed.

    Returns:
        Tensor holding the maxima of ``|data|`` taken over dimension 0.
    """
    magnitudes = data.abs()
    reduced = magnitudes.max(dim=0)
    return reduced.values

# Function to plot waveforms
def plot_waveform(data, p_arrival, s_arrival, coda_end, index, sampling_rate=100):
    """Draw a three-panel waveform figure with P, S and coda-end markers.

    Args:
        data: 2-D array; the first 3000 rows are plotted and columns 0, 1, 2
            are titled 'E-W', 'N-S' and 'Vertical' respectively.
        p_arrival: P-arrival position in samples.
        s_arrival: S-arrival position in samples.
        coda_end: Coda-end position in samples.
        index: Zero-based waveform number; the titles show ``index + 1``.
        sampling_rate: Samples per second, used to convert samples to seconds.
    """
    data = data[:3000, :]
    time_axis = np.arange(0, data.shape[0]) / sampling_rate

    # (sample position, line colour, legend label) for each vertical marker
    markers = (
        (p_arrival, 'b', 'P-arrival'),
        (s_arrival, 'r', 'S-arrival'),
        (coda_end, 'aqua', 'Coda End'),
    )
    component_names = ('E-W', 'N-S', 'Vertical')

    fig, axes = plt.subplots(3, 1, figsize=(12, 8))
    for col, (ax, component) in enumerate(zip(axes, component_names)):
        ax.plot(time_axis, data[:, col], 'k')
        for sample, colour, label in markers:
            ax.axvline(x=sample / sampling_rate, color=colour, linewidth=2, label=label)
        ax.set_ylabel('Amplitude counts', fontweight='bold', fontsize=14)
        ax.legend(loc='upper right', fontsize=14)
        ax.set_title(f'Waveform {index+1}: {component} Component',
                     fontweight='bold', fontsize=14)
        ax.tick_params(axis='both', which='major', labelsize=14)

    # Only the bottom panel carries the shared time-axis label.
    axes[2].set_xlabel('Time (s)', fontweight='bold', fontsize=14)

    plt.tight_layout()
    plt.show()



# Part 8: Extract and Analyze Maximum Amplitudes

In [None]:
#------------------------------------------------------------------------------
# Part 8: Extract and Analyze Maximum Amplitudes
#------------------------------------------------------------------------------

# Plot example waveforms
print("\nPlotting example waveforms...")
with h5py.File(file_name, 'r') as dtfl:
    # Slice rather than range(5) so that fewer than five valid events
    # does not raise an IndexError.
    for i, evi in enumerate(valid_events[:5]):  # Plot first 5 waveforms
        dataset = dtfl.get(f'data/{evi}')
        if dataset is None:
            continue

        data_np = np.array(dataset)
        p_arrival = dataset.attrs['p_arrival_sample']
        s_arrival = dataset.attrs['s_arrival_sample']
        coda_end = dataset.attrs['coda_end_sample']

        plot_waveform(data_np, p_arrival, s_arrival, coda_end, i)

        print(f"\nAttributes of event {i+1}:")
        for attr in dataset.attrs:
            print(f"{attr}: {dataset.attrs[attr]}")
        print("\n")

# Initialize lists for max values
print("\nExtracting maximum amplitudes...")
max_values = []
log_max_values = []

# Process each event
for i in tqdm(range(len(data)), desc="Processing events"):
    try:
        event_max = extract_max_values(data[i])
        log_event_max = torch.log10(event_max + 1e-10)  # Add small constant to avoid log(0)
    except Exception as e:
        # Report the offending event, then re-raise. Stopping silently would
        # leave max_values shorter than labels and break the later analysis.
        print(f"Error processing event {i}: {e}")
        print(f"Shape of data[{i}]: {data[i].shape}")
        raise
    max_values.append(event_max)
    log_max_values.append(log_event_max)

if not max_values:
    raise ValueError("No waveforms loaded: cannot extract maximum amplitudes.")

# Convert lists to tensors
max_values = torch.stack(max_values)
log_max_values = torch.stack(log_max_values)

# Save the extracted values
np.save('max_values.npy', max_values.numpy())
np.save('log_max_values.npy', log_max_values.numpy())

print("\nExtraction complete. Results saved.")
print(f"Max values shape: {max_values.shape}")
print(f"Log max values shape: {log_max_values.shape}")

# Part 9: Statistical Analysis and Visualization

In [None]:
#------------------------------------------------------------------------------
# Part 9: Statistical Analysis and Visualization
#------------------------------------------------------------------------------

# Analyze vertical component (index 2)
vertical_index = 2

# NumPy views used by every plot and statistic in this cell
magnitudes_np = labels.numpy()
vertical_max = max_values[:, vertical_index].numpy()
vertical_log_max = log_max_values[:, vertical_index].numpy()

# Plot histograms of max values (raw on the left, log10 on the right)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, values, xlabel in (
    (axes[0], vertical_max, 'Stream Max Value'),
    (axes[1], vertical_log_max, 'Log Stream Max Value'),
):
    ax.hist(values, bins=50)
    ax.set_xlabel(xlabel, fontweight='bold', fontsize=14)
    ax.set_ylabel('Number', fontweight='bold', fontsize=14)
    ax.tick_params(axis='both', which='major', labelsize=14)
fig.tight_layout()
plt.show()

# Correlation analysis (Pearson coefficient from the 2x2 correlation matrix)
correlation = np.corrcoef(vertical_max, magnitudes_np)[0, 1]
correlation_log = np.corrcoef(vertical_log_max, magnitudes_np)[0, 1]

print("\nCorrelation Analysis:")
print(f"Correlation between vertical stream max amplitude and magnitude: {correlation:.4f}")
print(f"Correlation between vertical log stream max amplitude and magnitude: {correlation_log:.4f}")


def _scatter_vs_magnitude(magnitudes, amplitudes, ylabel):
    """Scatter amplitudes against magnitude, colouring points by magnitude."""
    fig, ax = plt.subplots(figsize=(10, 6))
    points = ax.scatter(magnitudes, amplitudes, alpha=0.5, c=magnitudes, cmap='viridis')
    ax.set_xlabel('Earthquake Magnitude', fontweight='bold', fontsize=14)
    ax.set_ylabel(ylabel, fontweight='bold', fontsize=14)
    ax.tick_params(axis='both', which='major', labelsize=14)
    cbar = fig.colorbar(points, ax=ax)
    # The colour channel is the magnitude itself (c=magnitudes), not a point
    # density, so the colorbar is labelled accordingly.
    cbar.set_label('Earthquake Magnitude', fontweight='bold', fontsize=14)
    cbar.ax.tick_params(labelsize=14)
    fig.tight_layout()
    plt.show()


# Scatter plots
_scatter_vs_magnitude(magnitudes_np, vertical_max, 'Stream Max Amplitude (Vertical)')
_scatter_vs_magnitude(magnitudes_np, vertical_log_max, 'Log Stream Max Amplitude (Vertical)')

# Part 10: Save Final Processed Data

In [None]:
#------------------------------------------------------------------------------
# Part 10: Save Final Processed Data
#------------------------------------------------------------------------------

# Destination files for the waveform tensor and its magnitude labels
output_data_file = "pre_processed_data.npy"
output_labels_file = "pre_processed_labels.npy"

# Move each tensor to host memory, convert to NumPy, and write it to disk
data_array = data.cpu().numpy()
np.save(output_data_file, data_array)
labels_array = labels.cpu().numpy()
np.save(output_labels_file, labels_array)

print(f"\nPre-processed Data saved to {output_data_file}")
print(f"Pre-processed Labels saved to {output_labels_file}")

# Report total runtime, measured from start_time set in the setup cell
end_time = time.time()
elapsed_time = end_time - start_time
elapsed_minutes = elapsed_time / 60
print(f"\nTotal execution time: {elapsed_minutes:.2f} minutes")