In [None]:
import os

import numpy as np
import obspy
import pandas as pd
from obspy import UTCDateTime
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import FDSNNoDataException
from tqdm import tqdm


def split_time_range(start_time, end_time, interval_days=365):
    """
    Splits a time range into consecutive, non-overlapping intervals.

    Parameters:
    ----------
    start_time : UTCDateTime
        Start of the time range.
    end_time : UTCDateTime
        End of the time range.
    interval_days : int or float, optional
        Length of each interval in days (default: 365). Must be positive.

    Returns:
    -------
    list of tuple
        List of (start, end) pairs for each interval, in chronological order.
        The last interval is clipped to `end_time`. The list is empty when
        `start_time` is not earlier than `end_time`.

    Raises:
    ------
    ValueError
        If `interval_days` is zero or negative.
    """
    # A zero or negative step would never move `current_start` forward, so the
    # loop below would run forever while the result list keeps growing.
    if interval_days <= 0:
        raise ValueError(f"interval_days must be positive, got {interval_days}")

    # Interval length in seconds, computed once outside the loop.
    interval_seconds = interval_days * 24 * 3600

    intervals = []
    current_start = start_time
    while current_start < end_time:
        # Clip the final chunk so it never extends past the requested range.
        current_end = min(current_start + interval_seconds, end_time)
        intervals.append((current_start, current_end))
        current_start = current_end
    return intervals


def preprocess_phase(phase):
    """
    Classify a phase hint into the P or S wave category.

    The match is case-insensitive and prefix based: a hint is classified by
    the first listed phase name it starts with.

    Parameters:
    phase: str or None, phase hint from picks

    Returns:
    str: "P" or "S", or None if the hint is missing, empty or unclassified
    """
    # A pick may carry no phase hint at all; report it as unclassified instead
    # of failing on `None.lower()`.
    if not isinstance(phase, str) or not phase:
        return None

    # Lowercase for consistency
    phase = phase.lower()

    # Phase names that count as P waves. Because "p" itself is in the list,
    # every hint beginning with "p" matches; the longer names document intent.
    p_phases = ("p", "pp", "pg", "pn", "pb", "pdiff", "pkp", "pkikp")

    # Phase names that count as S waves (same remark for "s").
    s_phases = ("s", "ss", "sg", "sn", "sb", "sdiff", "sks", "skiks")

    # str.startswith accepts a tuple of candidate prefixes.
    if phase.startswith(p_phases):
        return "P"
    if phase.startswith(s_phases):
        return "S"
    return None


def _event_pick_rows(event, station_code, evaluation_mode):
    """Return one row dict per classified pick of `event` that passes the filters."""
    origin = event.preferred_origin() or (event.origins[0] if event.origins else None)
    if origin is None or origin.time is None:
        # Without an origin time the picks cannot be tied to an event window.
        return []

    rows = []
    for pick in event.picks:
        waveform_id = pick.waveform_id
        # Picks without a station reference or phase hint cannot be used.
        if waveform_id is None or not pick.phase_hint:
            continue
        if station_code and waveform_id.station_code != station_code:
            continue
        # Skip picks that don't match the specified evaluation mode
        if evaluation_mode != "all" and pick.evaluation_mode != evaluation_mode:
            continue

        processed_phase = preprocess_phase(pick.phase_hint)
        if not processed_phase:  # Only save if phase is classified
            continue

        rows.append(
            {
                "event_id": str(event.resource_id),
                "event_time": origin.time.datetime,
                "event_latitude": origin.latitude,
                "event_longitude": origin.longitude,
                "event_depth": origin.depth,
                "station": waveform_id.station_code,
                "network": waveform_id.network_code,
                "channel": waveform_id.channel_code,
                "phase": processed_phase,  # Use processed phase
                "original_phase": pick.phase_hint,  # Save original phase
                "pick_time": pick.time.datetime,
                "evaluation_mode": pick.evaluation_mode,
                "evaluation_status": pick.evaluation_status,
            }
        )
    return rows


def get_phase_picks(
    start_time,
    end_time,
    station_code=None,
    evaluation_mode="all",
    retry_count=3,
    retry_delay=5,
):
    """
    Download phase picks with error handling and retries, with evaluation mode filtering.

    Events are requested from the GFZ FDSN service inside a fixed
    latitude/longitude box (-11..6, 95..141). Only picks whose phase hint is
    classified as P or S by `preprocess_phase` are kept.

    Parameters:
    ----------
    start_time : UTCDateTime
        Start time for retrieving data.
    end_time : UTCDateTime
        End time for retrieving data.
    station_code : str, optional
        Code of the seismic station to filter picks (default: None, includes all stations).
    evaluation_mode : str, optional
        Filter picks by evaluation mode. Can be "manual", "automatic", or "all" (default: "all").
    retry_count : int, optional
        Number of download attempts in case of errors (default: 3).
    retry_delay : float, optional
        Seconds to wait between two attempts (default: 5).

    Returns:
    -------
    pd.DataFrame
        DataFrame containing picks information, filtered by the specified evaluation mode.
        It is empty when events were found but no pick passed the filters.
        Returns None if the service reports no data or every attempt failed.
    """
    import time  # local import: only needed for the pause between retries

    client = Client("GFZ")

    # Bounding box of the event search, in degrees.
    minlat = -11.0
    maxlat = 6.0
    minlon = 95.0
    maxlon = 141.0

    for attempt in range(retry_count):
        # Only the download is retried; parsing problems are deterministic and
        # would fail identically on every attempt.
        try:
            catalog = client.get_events(
                starttime=start_time,
                endtime=end_time,
                minlatitude=minlat,
                maxlatitude=maxlat,
                minlongitude=minlon,
                maxlongitude=maxlon,
                includearrivals=True,
            )
        except FDSNNoDataException:
            # "No data" is a definitive answer, not a transient failure.
            print(f"No events found between {start_time} and {end_time}")
            return None
        except Exception as e:
            if attempt < retry_count - 1:
                print(
                    f"Attempt {attempt + 1} failed, retrying after {retry_delay} seconds..."
                    f" ({str(e)})"
                )
                time.sleep(retry_delay)
                continue
            print(f"Failed to get picks after {retry_count} attempts: {str(e)}")
            return None

        picks_data = []
        for event in catalog:
            # Best effort per event: one malformed event must not discard the
            # whole catalogue that was just downloaded.
            try:
                picks_data.extend(
                    _event_pick_rows(event, station_code, evaluation_mode)
                )
            except Exception as e:
                print(f"Skipping event {event.resource_id}: {str(e)}")

        return pd.DataFrame(picks_data)

    # Reached only when retry_count is zero or negative.
    return None


def create_phase_labels(picks_df, waveform_stream, window_start, window_length):
    """
    Build per-sample P and S label arrays for one waveform window.

    Parameters:
    picks_df: DataFrame of picks; its `pick_time` and `phase` columns are read
    waveform_stream: ObsPy Stream object; only the sampling rate of its first trace is used
    window_start: UTCDateTime, start time of window
    window_length: float, length of window in seconds

    Returns:
    numpy array of shape (2, n_samples): row 0 marks P picks, row 1 marks S picks
    (0 = no pick, 1 = pick)
    """
    # NOTE(review): `pick_time` values are compared with and subtracted from
    # `window_start` directly, so both must be of mutually compatible time
    # types -- confirm against the caller.
    last_instant = window_start + window_length
    inside_window = (picks_df.pick_time >= window_start) & (
        picks_df.pick_time <= last_instant
    )
    candidates = picks_df[inside_window]

    # The label arrays use the sampling rate of the waveform.
    rate = waveform_stream[0].stats.sampling_rate
    n_samples = int(window_length * rate)

    # One row per phase class, all zeros until a pick is placed.
    labels = np.zeros((2, n_samples))

    for _, entry in candidates.iterrows():
        # Offset of the pick from the window start, converted to a sample index.
        position = int((entry.pick_time - window_start) * rate)
        if not 0 <= position < n_samples:
            continue

        for row, code in enumerate(("P", "S")):
            if entry.phase == code:
                labels[row, position] = 1
                break

    return labels


def preprocess_waveform(event_stream, resample_rate=None):
    """
    Preprocess waveform data to ensure consistent shape across components and optional resampling.

    The stream is modified in place: it is resampled when `resample_rate` is
    given, and traces are truncated to the shortest one when their lengths
    differ. A caller that reads the sampling rate from the stream afterwards
    therefore sees the resampled rate.

    Parameters:
    ----------
    event_stream : ObsPy Stream object
        Waveform stream for an event; three traces with distinct component
        codes (last character of the channel code) are expected.
    resample_rate : float, optional
        Target sampling rate in Hz for resampling (default: None, no resampling).

    Returns:
    -------
    numpy.ndarray
        Array of waveform data with shape (3, n_samples) or None if preprocessing fails.
        Rows are sorted by component code in descending order, which gives
        Z, N, E for the usual channel naming, independent of the order in
        which the traces were delivered.
    """
    try:
        # Check if all components have the same sampling rate
        sampling_rates = {tr.stats.sampling_rate for tr in event_stream}
        if len(sampling_rates) != 1:
            print("Inconsistent sampling rates found")
            return None

        # Validate the component set before spending time on resampling.
        components = [tr.stats.channel[-1:] for tr in event_stream]
        if len(components) != 3:
            print(f"Unexpected number of components: {len(components)}")
            return None
        if len(set(components)) != 3:
            # e.g. two segments of one channel: not a three-component record.
            print(f"Duplicate components found: {components}")
            return None

        # Resample if a target rate is specified
        if resample_rate:
            event_stream.resample(sampling_rate=resample_rate)

        # Get lengths of all traces
        lengths = [len(tr.data) for tr in event_stream]
        if len(set(lengths)) != 1:
            print(f"Inconsistent trace lengths found: {lengths}")
            # Use shortest length
            min_length = min(lengths)
            for tr in event_stream:
                tr.data = tr.data[:min_length]

        # Fix the row order so every event uses the same component layout;
        # the stream itself keeps its original trace order.
        ordered = sorted(
            event_stream, key=lambda tr: tr.stats.channel[-1:], reverse=True
        )

        # Convert to numpy array
        return np.array([tr.data for tr in ordered])

    except Exception as e:
        print(f"Error preprocessing waveform: {str(e)}")
        return None


def _label_event_picks(event_picks, window_start, window_end, sampling_rate, n_samples):
    """Return a (2, n_samples) array with 1 at each P (row 0) and S (row 1) pick sample."""
    labels = np.zeros((2, n_samples))
    phase_row = {"P": 0, "S": 1}

    for pick in event_picks.itertuples(index=False):
        # The `phase` column already holds the classified "P"/"S" value.
        row = phase_row.get(pick.phase)
        if row is None:
            continue
        pick_time = UTCDateTime(pick.pick_time)
        if window_start <= pick_time <= window_end:
            idx = int((pick_time - window_start) * sampling_rate)
            if idx < n_samples:
                labels[row, idx] = 1
    return labels


def prepare_dataset_with_labels(
    station_code,
    start_time,
    end_time,
    pre_event_time=30,
    window_length=120,
    evaluation_mode="all",
    resample_rate=None,
):
    """
    Prepares a dataset of waveform data and corresponding phase labels for a specific seismic station.

    Phase picks for the station are fetched with `get_phase_picks`. For every
    event that has picks, a waveform window of network "GE", channels "BH*"
    is downloaded from the GFZ FDSN client, preprocessed with
    `preprocess_waveform` and paired with a label array marking the P and S
    picks that fall inside the window.

    Parameters:
    ----------
    station_code : str
        Code of the seismic station (e.g., "BBJI").
    start_time : UTCDateTime
        Start time for retrieving data.
    end_time : UTCDateTime
        End time for retrieving data.
    pre_event_time : float, optional
        Time in seconds to include before the event time in the waveform window (default: 30).
    window_length : float, optional
        Total length of the waveform window in seconds (default: 120).
    evaluation_mode : str, optional
        Filter phase picks by evaluation mode (default: "all").
    resample_rate : float, optional
        Target sampling rate in Hz for resampling (default: None, no resampling).

    Returns:
    -------
    tuple
        A tuple `(X, y)` where:
        - `X` : numpy.ndarray
            Array of waveform data with shape `(n_events, 3, n_samples)`, where:
            - `n_events` : Number of events with a usable waveform window.
            - `3` : Waveform components, in the row order produced by `preprocess_waveform`.
            - `n_samples` : `window_length` times the sampling rate, identical for every event.
        - `y` : numpy.ndarray
            Array of phase labels with shape `(n_events, 2, n_samples)`, where:
            - `2` : Label dimensions for P (index 0) and S (index 1) phases.
            - `n_samples` : Number of samples in each label array.
        Returns `(None, None)` if no valid waveform or label data is found.

    Example:
    --------
    >>> from obspy import UTCDateTime
    >>> X, y = prepare_dataset_with_labels(
    ...     station_code="BBJI",
    ...     start_time=UTCDateTime("2023-01-01"),
    ...     end_time=UTCDateTime("2023-01-02"),
    ...     pre_event_time=60,
    ...     window_length=600
    ... )
    >>> print(X.shape)  # (n_events, 3, n_samples)
    >>> print(y.shape)  # (n_events, 2, n_samples)

    Notes:
    ------
    - There is no up-front availability probe. Missing waveform data is
      handled per event, so a gap at the start of the range does not discard
      the station.
    - Windows longer than `n_samples` are truncated; windows with fewer
      samples, or with a different sampling rate than the first accepted
      event, are skipped so that the result stacks into a regular array.
    - Events whose stream does not contain exactly three traces are skipped.
    """
    client = Client("GFZ")

    # Get picks for this station
    picks_df = get_phase_picks(
        start_time=start_time,
        end_time=end_time,
        station_code=station_code,
        evaluation_mode=evaluation_mode,
    )

    if picks_df is None or picks_df.empty:
        print(f"No picks found for station {station_code}")
        return None, None

    X = []
    y = []
    # Sample count of the first accepted event; all later events must match it.
    target_samples = None

    # Group on the event identifier (order of first appearance is kept), so
    # the event time is a plain scalar taken from the rows.
    for _, event_picks in picks_df.groupby("event_id", sort=False):
        event_time = event_picks["event_time"].iloc[0]
        try:
            window_start = UTCDateTime(event_time) - pre_event_time
            window_end = window_start + window_length

            event_stream = client.get_waveforms(
                network="GE",
                station=station_code,
                location="*",
                channel="BH*",
                starttime=window_start,
                endtime=window_end,
            )
            if len(event_stream) != 3:
                continue

            # Preprocess waveform data (resamples the stream in place)
            waveform_data = preprocess_waveform(
                event_stream, resample_rate=resample_rate
            )
            if waveform_data is None:
                continue

            # Read the rate after preprocessing so it reflects any resampling.
            sampling_rate = event_stream[0].stats.sampling_rate
            expected_samples = int(round(window_length * sampling_rate))

            available = waveform_data.shape[1]
            if available < expected_samples:
                print(
                    f"Skipping event at {event_time}: incomplete waveform "
                    f"({available} of {expected_samples} samples)"
                )
                continue
            if target_samples is None:
                target_samples = expected_samples
            elif expected_samples != target_samples:
                print(
                    f"Skipping event at {event_time}: {expected_samples} samples "
                    f"expected, dataset uses {target_samples}"
                )
                continue

            # The service may deliver an extra end sample; cut every window to
            # the same length so np.array() below gets a regular shape.
            waveform_data = waveform_data[:, :expected_samples]

            # Create labels
            labels = _label_event_picks(
                event_picks, window_start, window_end, sampling_rate, expected_samples
            )

            X.append(waveform_data)
            y.append(labels)

        except FDSNNoDataException:
            # No waveform for this window: nothing to report, try the next event.
            continue
        except Exception as e:
            print(f"Error processing event at {event_time}: {str(e)}")
            continue

    if X:
        return np.array(X), np.array(y)
    return None, None


# Station table used by the example run below
def get_ge_stations():
    """
    Build the table of GE network stations in Indonesia.

    Returns:
    dict: maps each station code to {"name": location description};
    a new dictionary is created on every call
    """
    code_and_place = (
        ("BBJI", "Bungbulang, Garut, Java"),
        ("CISI", "Cisomped, Java"),
        ("GSI", "Gunungsitoli, Nias"),
        ("BKB", "Balikpapan, Kalimantan"),
        ("BKNI", "Bangkinang, Sumatra"),
        ("BNDI", "Bandaneira, Indonesia"),
        ("FAKI", "Fak Fak, Irian Jaya"),
        ("GENI", "Genyem, Irian Jaya"),
        ("JAGI", "Jajag, Java"),
        ("LHMI", "Lhokseumawe, Sumatra"),
        ("LUWI", "Luwuk, Sulawesi"),
        ("MMRI", "Maumere, Flores"),
        ("MNAI", "Manna, Sumatra"),
        ("PLAI", "Plampang, Sumbawa"),
        ("PMBI", "Palembang, Sumatra"),
        ("PMBT", "Palembang, Sumatra"),
        ("SANI", "Sanana, Moluccas"),
        ("SAUI", "Saumlaki, Tanimbar"),
        ("SMRI", "Semarang, Java"),
        ("SOEI", "Soe, Timor"),
        ("TNTI", "Ternate, Indonesia"),
        ("TOLI", "Tolitoli, Sulawesi"),
        ("TOLI2", "Tolitoli, Sulawesi"),
        ("UGM", "Wanagama, Indonesia"),
        ("YOGI", "Yogyakarta, Java"),
    )
    # Insertion order of the pairs is the iteration order of the result.
    return {code: {"name": place} for code, place in code_and_place}


if __name__ == "__main__":
    # Time period and pick filter for the whole run
    start_time = UTCDateTime("2015-01-01")
    end_time = UTCDateTime("2024-01-01")
    EVALUATION_MODE = "manual"

    # Create output directory (no error if it already exists)
    output_dir = "seismic_data"
    os.makedirs(output_dir, exist_ok=True)

    # Download the picks in yearly chunks rather than in a single request
    intervals = split_time_range(start_time, end_time, interval_days=365)
    all_picks = []

    for interval_start, interval_end in intervals:
        print(f"Downloading phase picks from {interval_start} to {interval_end}...")
        interval_picks = get_phase_picks(
            interval_start, interval_end, evaluation_mode=EVALUATION_MODE
        )

        if interval_picks is not None and not interval_picks.empty:
            print(
                f"Found {len(interval_picks)} picks for interval {interval_start} to {interval_end}"
            )
            all_picks.append(interval_picks)

    if all_picks:
        # Combine all picks into one DataFrame. `picks_df` stays a module-level
        # name on purpose: a later notebook cell reads it.
        picks_df = pd.concat(all_picks, ignore_index=True)
        print(f"Total picks collected: {len(picks_df)}")

        # Save all picks to a CSV file
        picks_df.to_csv(os.path.join(output_dir, "all_picks.csv"), index=False)

        # Get list of stations
        stations = get_ge_stations()

        # Process each station
        for station_code, station_info in tqdm(
            stations.items(), desc="Processing stations"
        ):
            try:
                print(
                    f"\nProcessing station {station_code} ({station_info['name']})..."
                )

                # NOTE(review): prepare_dataset_with_labels fetches the picks
                # again itself, for the full time range in one request, and
                # does not reuse `picks_df` -- confirm this is intended.
                X, y = prepare_dataset_with_labels(
                    station_code=station_code,
                    start_time=start_time,
                    end_time=end_time,
                    pre_event_time=60,
                    window_length=600,
                    evaluation_mode=EVALUATION_MODE,
                    resample_rate=100.0,  # Resample to 100 Hz
                )

                if X is not None and len(X) > 0:
                    # Create station-specific directory
                    station_dir = os.path.join(output_dir, station_code)
                    os.makedirs(station_dir, exist_ok=True)

                    # Save datasets
                    print(f"Saving dataset with {len(X)} samples")
                    print(f"Waveform shape: {X.shape}")
                    print(f"Labels shape: {y.shape}")

                    np.save(os.path.join(station_dir, "waveforms.npy"), X)
                    np.save(os.path.join(station_dir, "labels.npy"), y)

                    # Save station-specific picks
                    station_picks = picks_df[picks_df.station == station_code]
                    station_picks.to_csv(
                        os.path.join(station_dir, "picks.csv"), index=False
                    )

                else:
                    print(f"No valid data found for station {station_code}")

            except Exception as e:
                # Keep going: one failing station must not stop the others.
                print(f"Error processing station {station_code}: {str(e)}")
                continue

        print("\nProcessing complete!")
        print(f"Data saved in directory: {output_dir}")
    else:
        print("No picks found for the specified time range.")

Downloading phase picks from 2015-01-01T00:00:00.000000Z to 2016-01-01T00:00:00.000000Z...


In [None]:
# Show only the rows of picks_df whose classified phase is "S"
picks_df.loc[picks_df["phase"] == "S"]