In [None]:
import os
from collections import defaultdict
import time

import numpy as np
import rf
import rf.imaging
import matplotlib.pyplot as plt
import scipy
from scipy import signal
from scipy.signal import hilbert
from scipy.stats import moment
from scipy.interpolate import interp1d
import obspy
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [None]:
# Bring in interactive widgets capability. See https://towardsdatascience.com/interactive-controls-for-jupyter-notebooks-f5c94829aee6
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
# from beakerx import *

## Read source file

In [None]:
# Source HDF5 file of receiver functions, built with os.path.join so the path works on any OS.
src_file = os.path.join("..", "DATA", "OA_event_waveforms_for_rf_20170911T000036-20181128T230620_LQT_td_rev3_qual.h5")

In [None]:
# Read every trace in the HDF5 file. To load a single station only, pass its group, e.g.:
# oa_all = rf.read_rf(src_file, format='h5', group='/waveforms/OA.BT23.0M')
oa_all = rf.read_rf(src_file, format='h5')

In [None]:
# import copy
# oa_copy = copy.deepcopy(oa_all)

In [None]:
# oa_copy.moveout()

In [None]:
# plt.figure(figsize=(16,9))
# time_offset = oa_all[0].stats.onset - oa_all[0].stats.starttime
# plt.plot(oa_all[0].times() - time_offset, oa_all[0].data)
# time_offset = oa_copy[0].stats.onset - oa_copy[0].stats.starttime
# plt.plot(oa_copy[0].times() - time_offset, oa_copy[0].data, '--')
# plt.grid()
# plt.show()

## Define useful internal functions

In [None]:
def rf_to_dict(rf_data):
    """Index traces by station code, then by channel code.

    Every trace id must have exactly four dot-separated parts (``NET.STA.LOC.CHA``);
    otherwise unpacking raises ValueError.

    :param rf_data: Iterable of traces exposing an ``id`` attribute.
    :return: Nested ``defaultdict`` mapping ``station -> channel -> [trace, ...]``.
        Unknown stations/channels read as empty lists.
    """
    by_station = defaultdict(lambda: defaultdict(list))
    for trace in rf_data:
        _network, station, _location, channel = trace.id.split('.')
        by_station[station][channel].append(trace)
    return by_station

In [None]:
def plot_station_rf_overlays(db_station, title=None):
    """Overlay every trace of each non-empty channel of a station and draw the channel mean on top.

    One subplot is drawn per non-empty channel. Times are made relative to each trace's onset
    (``stats.onset - stats.starttime`` is subtracted), and all subplots share the same x-range.
    Samples are averaged index-by-index, so traces within a channel are assumed to have equal
    length - TODO confirm for all inputs.

    :param db_station: Mapping of channel code -> list of traces (e.g. one station of rf_to_dict()).
    :param title: Optional text appended to every subplot title.
    :return: List with one mean signal per non-empty channel, in plotting order. Samples where no
        trace has finite data are NaN. Empty list (and no figure) if the station has no traces.
    """
    # Only channels that actually hold traces get a subplot. Indexing subplots by position in this
    # filtered list keeps subplot numbers within 1..num_channels even when some channels are empty.
    populated = [(ch, traces) for ch, traces in db_station.items() if traces]
    num_channels = len(populated)
    signal_means = []
    if num_channels == 0:
        return signal_means

    _, axes = plt.subplots(num_channels, 1, figsize=(16, 8*num_channels), squeeze=False)
    axes = axes[:, 0]
    colors = ["#8080a040", "#80a08040", "#a0808040"]

    for i, (ch, traces) in enumerate(populated):
        ax = axes[i]
        # Cycle the palette so stations with more than three channels do not raise IndexError.
        col = colors[i % len(colors)]
        sta = traces[0].stats.station
        data_sum = None
        counts = None
        for tr in traces:
            lead_time = tr.stats.onset - tr.stats.starttime
            times = tr.times() - lead_time
            ax.plot(times, tr.data, '--', color=col, linewidth=2)
            # Accumulate only finite samples so NaN/inf in one trace does not poison the mean.
            mask = np.isfinite(tr.data)
            if data_sum is None:
                data_sum = np.zeros_like(tr.data)
                counts = np.zeros(tr.data.shape, dtype=float)
            data_sum[mask] += tr.data[mask]
            counts += mask.astype(float)
        # Mean over contributing traces; samples with no contribution stay NaN (no divide-by-zero warning).
        data_mean = np.full(counts.shape, np.nan)
        np.divide(data_sum, counts, out=data_mean, where=counts > 0)
        signal_means.append(data_mean)
        # The mean is drawn on the onset-relative time axis of the last trace of the channel.
        ax.plot(times, data_mean, color="#202020", linewidth=2)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude (normalized)')
        ax.grid(linestyle=':', color="#80808020")
        title_text = '.'.join([sta, ch])
        if title is not None:
            title_text += ' ' + title
        ax.set_title(title_text, fontsize=14)

    # Use one common time range across all channel subplots.
    min_x = min(a.get_xlim()[0] for a in axes)
    max_x = max(a.get_xlim()[1] for a in axes)
    for a in axes:
        a.set_xlim((min_x, max_x))

    return signal_means

In [None]:
def signed_nth_root(arr, order):
    """Return the sign-preserving ``order``-th root, ``sign(x) * |x| ** (1 / order)``, element-wise.

    For ``order == 1`` the input object itself is returned untouched (no conversion, no copy).
    """
    if order != 1:
        sign = np.sign(arr)
        magnitude = np.power(np.abs(arr), 1.0/order)
        return sign*magnitude
    return arr

In [None]:
def signed_nth_power(arr, order):
    """Return ``sign(x) * |x| ** order`` element-wise (the inverse of signed_nth_root()).

    For ``order == 1`` the input object itself is returned untouched (no conversion, no copy).
    """
    if order != 1:
        sign = np.sign(arr)
        magnitude = np.power(np.abs(arr), order)
        return sign*magnitude
    return arr

In [None]:
def compute_hk_stack(db_station, cha, h_range=np.linspace(10.0, 70.0, 301), k_range = np.linspace(1.3, 2.1, 201),
                     V_p = 6.4, root_order=1, include_t3=True):
    """Compute per-phase H-k stacks for one channel of a station.

    For every (H, kappa) grid node, each trace is linearly interpolated at the predicted times of
    Ps, PpPs and (optionally) PpSs + PsPs. The samples are then averaged over traces (NaN-aware),
    keeping each phase separate.

    :param db_station: Mapping channel code -> list of traces. Each trace must provide
        ``stats.inclination`` (degrees), ``stats.onset``, ``stats.starttime``, ``times()`` and ``data``.
    :param cha: Channel code to stack.
    :param h_range: Candidate H values (Moho depth; presumably km - TODO confirm).
    :param k_range: Candidate kappa = Vp/Vs ratios.
    :param V_p: Assumed crustal P-wave velocity (units presumably km/s, consistent with H in km - TODO confirm).
    :param root_order: n for the nth-root stacking technique; 1 disables it.
    :param include_t3: Whether to include the PpSs + PsPs term.
    :return: ``(k_grid, h_grid, hk_stack)``. ``k_grid``/``h_grid`` have shape ``(len(h_range), len(k_range))``.
        ``hk_stack`` has shape ``(num_phases, len(h_range), len(k_range))``, where ``num_phases`` is 3 if
        ``include_t3`` else 2.
    :raises ValueError: If the station has no traces for ``cha``.
    """
    cha_data = db_station[cha]
    if len(cha_data) == 0:
        # Otherwise nanmean() of an empty array yields a 0-d NaN that breaks callers later on.
        raise ValueError("No traces available for channel {}".format(cha))

    # Pre-compute grid quantities shared by all traces.
    k_grid, h_grid = np.meshgrid(k_range, h_range)
    H_on_V_p = h_grid/V_p
    k2 = k_grid*k_grid

    stream_stack = []
    # Loop over streams, compute times, and sample the interpolated signal at those times.
    for s in cha_data:
        incidence_rad = s.stats.inclination*np.pi/180.0
        cos_i, sin_i = np.cos(incidence_rad), np.sin(incidence_rad)
        # NOTE(review): these expressions equal the Zhu & Kanamori (2000) moveout times when
        # p*V_p = kappa*sin(i), i.e. when `inclination` is the S-wave incidence angle. For the P-wave
        # incidence angle, Ps would be (H/V_p)*(sqrt(kappa^2 - sin^2 i) - cos i) - TODO confirm convention.
        term1 = H_on_V_p*k_grid*np.abs(cos_i)
        term2 = H_on_V_p*np.sqrt(1 - k2*(sin_i*sin_i))
        phase_times = [term1 - term2,   # Ps
                       term1 + term2]   # PpPs
        if include_t3:
            phase_times.append(2*term1)  # PpSs + PsPs

        # Subtract lead time so that primary P-wave arrival is at time zero.
        lead_time = s.stats.onset - s.stats.starttime
        times = s.times() - lead_time
        # Linear interpolation; with bounds_error=False, times outside the trace give NaN (ignored by nanmean).
        interpolator = interp1d(times, s.data, kind='linear', copy=False, bounds_error=False, assume_sorted=True)

        # Apply nth root technique to reduce uncorrelated noise (Chen et al. (2010)).
        phase_sum = [signed_nth_root(interpolator(t), root_order) for t in phase_times]
        if include_t3:
            # Negative sign on the third term is intentional, see Chen et al. (2010) and Zhu & Kanamori (2000).
            # It needs to be negative because the PpSs + PsPs peak has negative phase,
            # see http://eqseis.geosc.psu.edu/~cammon/HTML/RftnDocs/rftn01.html
            phase_sum[2] = -phase_sum[2]
        stream_stack.append(phase_sum)

    # Average across streams. hk_stack retains separate t1, t2 (and t3) components: (num_phases, H, K).
    hk_stack = np.nanmean(np.array(stream_stack), axis=0)

    # This inversion of the nth root is different to Sippl and Chen, but consistent with Muirhead
    # who proposed the nth root technique. It improves the contrast of the resulting plot.
    if root_order != 1:
        hk_stack = signed_nth_power(hk_stack, root_order)

    return k_grid, h_grid, hk_stack

In [None]:
def compute_weighted_stack(hk_components, weighting=(0.5, 0.5, 0.0)):
    """Collapse per-phase H-k stacks into one weighted stack.

    :param hk_components: Array whose leading axis indexes phase components (as from compute_hk_stack()).
    :param weighting: One weight per component, in the order of the leading axis.
    :return: Weighted sum over the leading axis, with shape ``hk_components.shape[1:]``.
    :raises AssertionError: If the number of weights differs from the number of components.
    """
    num_components = hk_components.shape[0]
    assert num_components == len(weighting), hk_components.shape
    # Move the component axis last so a dot product with the weight vector contracts over it.
    components_last = np.moveaxis(hk_components, 0, -1)
    weights = np.array(weighting)
    return components_last.dot(weights)

In [None]:
def plot_hk_stack(k_grid, h_grid, hk_stack, title=None, save_file=None, show=True, num=None, clip_negative=True):
    """Plot an H-k stack as filled contours with line contours overlaid.

    Call compute_weighted_stack() first to combine weighted components before calling this
    function, or pass a single phase component directly.

    :param k_grid: kappa (Vp/Vs) coordinates of the grid, e.g. from compute_hk_stack().
    :param h_grid: H (Moho depth, km) coordinates of the grid.
    :param hk_stack: 2D stack values on the grid. The caller's array is never modified.
    :param title: Optional figure title.
    :param save_file: Optional output path; saved at dpi=300, retrying up to 10 times on PermissionError.
    :param show: If True call plt.show(), otherwise close the figure.
    :param num: Optional count printed as "N = <num>" near the top-right corner of the plot.
    :param clip_negative: If True, negative stack values are displayed as 0.
    """
    # Use a perceptually linear color map.
    colmap = 'plasma'
    if clip_negative:
        # Clip into a new array rather than assigning in place, which would write into the caller's
        # data (e.g. a view such as hk_stack[0] of the per-phase stacks).
        hk_stack = np.where(hk_stack < 0, 0, hk_stack)
    plt.figure(figsize=(16, 12))
    plt.contourf(k_grid, h_grid, hk_stack, levels=50, cmap=colmap)
    cb = plt.colorbar()
    cb.ax.set_ylabel('Stack sum')
    plt.contour(k_grid, h_grid, hk_stack, levels=10, colors='k', linewidths=1)
    plt.xlabel(r'$\kappa = \frac{V_p}{V_s}$ (ratio)', fontsize=14)
    plt.ylabel('H = Moho depth (km)', fontsize=14)
    if title is not None:
        plt.title(title, fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    if num is not None:
        # Place the annotation at 85% / 95% of the axis ranges (top-right area).
        xl = plt.xlim()
        yl = plt.ylim()
        txt_x = xl[0] + 0.85*(xl[1] - xl[0])
        txt_y = yl[0] + 0.95*(yl[1] - yl[0])
        plt.text(txt_x, txt_y, "N = {}".format(num), color="#ffffff", fontsize=16, fontweight='bold')

    if save_file is not None:
        max_tries = 10
        for attempt in range(1, max_tries + 1):
            try:
                plt.savefig(save_file, dpi=300)
                break
            except PermissionError:
                # Presumably a transient lock on the output file; wait and retry.
                if attempt == max_tries:
                    print("WARNING: Failed to save file {} due to permissions!".format(save_file))
                else:
                    time.sleep(1)

    if show:
        plt.show()
    else:
        plt.close()

In [None]:
def filter_station_streams(db_station, freq_band=(None, None)):
    """Return a copy of ``db_station`` whose streams have been frequency filtered.

    Each stream is copied before filtering, so the inputs are left unchanged. A key named
    ``'size'`` and any entry after the first three items are skipped.

    :param db_station: Mapping channel code -> list of streams supporting ``copy()`` and ``filter()``.
    :param freq_band: ``(fmin, fmax)``.
        - Only fmax set: 2-corner zero-phase lowpass.
        - Only fmin set: 2-corner zero-phase highpass.
        - Both set: 2-corner zero-phase bandpass.
        - Neither set: plain copies.
    :return: ``defaultdict(list)`` mapping channel code -> list of filtered stream copies.
    """
    freq_min, freq_max = freq_band[0], freq_band[1]
    db_station_filt = defaultdict(list)
    for idx, (ch, streams) in enumerate(db_station.items()):
        # NOTE(review): 'size' is presumably a non-channel bookkeeping entry - confirm.
        if ch == 'size' or idx >= 3:
            continue
        for s in streams:
            filtered = s.copy()
            if freq_min is None:
                if freq_max is not None:
                    filtered.filter('lowpass', zerophase=True, corners=2, freq=freq_max)
            elif freq_max is None:
                filtered.filter('highpass', zerophase=True, corners=2, freq=freq_min)
            else:
                filtered.filter('bandpass', zerophase=True, corners=2, freqmin=freq_min,
                                freqmax=freq_max)
            db_station_filt[ch].append(filtered)

    return db_station_filt

In [None]:
import copy
def filter_station_to_mean_signal(db_station, min_correlation=1.0):
    """Filter out streams which are not 'close enough' to their channel's mean signal.

    For each channel, the traces are summed and normalised so the summed signal's maximum is 1,
    giving a reference ``m``. Each trace ``s`` is scored as ``dot(s, m) / dot(m, m)`` and kept if
    the score is at least ``min_correlation``.

    The following are skipped:
    - a key named ``'size'``;
    - any entry after the first three items;
    - channels with no streams.

    :param db_station: Mapping channel code -> list of streams exposing a ``data`` array.
    :param min_correlation: Minimum score required to keep a stream.
    :return: ``(db_station_filt, corrs)``. ``db_station_filt`` is a ``defaultdict(list)`` of the kept
        streams (same objects, not copies) per channel. ``corrs`` lists the score of every evaluated
        stream in iteration order.
    """
    # Normalised reference signal per channel, keyed by channel code so that the scoring pass
    # below always pairs each stream with its own channel's reference.
    mean_rfs = {}
    for i, (ch, streams) in enumerate(db_station.items()):
        if ch == 'size' or i >= 3 or not streams:
            continue
        data_sum = np.array(streams[0].data, copy=True)
        for s in streams[1:]:
            data_sum += s.data
        # Out-of-place division so integer-typed data does not raise on an in-place float divide.
        mean_rfs[ch] = data_sum / np.max(data_sum)

    # Keep only signals that meet the minimum coherence with their channel's reference.
    db_station_filt = defaultdict(list)
    corrs = []
    for ch, streams in db_station.items():
        reference = mean_rfs.get(ch)
        if reference is None:
            continue
        ref_energy = np.dot(reference, reference)
        for s in streams:
            corr = np.dot(s.data, reference)/ref_energy
            if corr >= min_correlation:
                db_station_filt[ch].append(s)
            corrs.append(corr)

    return db_station_filt, corrs

In [None]:
def plot_rf_stack(rf_stream, time_window=(-10.0, 25.0), trace_height=0.2, stack_height=0.8, save_file=None):
    """Plot receiver-function traces and their stack via ``rf_stream.plot_rf``.

    :param rf_stream: Stream providing ``plot_rf`` (e.g. an rf.RFStream).
    :param time_window: ``(start, end)`` passed to ``plot_rf`` as ``trim``.
    :param trace_height: Height of each individual trace panel.
    :param stack_height: Height of the stack panel.
    :param save_file: Optional file name passed as ``fname``.
    :return: None; whatever ``plot_rf`` returns is discarded.
    """
    plot_options = dict(fillcolors=('#000000', '#a0a0a0'), trim=time_window, trace_height=trace_height,
                        stack_height=stack_height, fname=save_file)
    rf_stream.plot_rf(**plot_options)

In [None]:
def add_new_metadata(db_station):
    """Attach amplitude statistics to the ``stats`` of every trace, in place.

    Added fields:
    - ``rms_amp``: root-mean-square of ``data``.
    - ``mean_cplx_amp``: mean of the envelope ``|hilbert(data)|``.
    - ``amp_20pc``, ``amp_80pc``: 20th and 80th percentiles of that envelope.

    :param db_station: Mapping channel code -> list of traces with ``data`` and ``stats``.
    """
    all_traces = (tr for traces in db_station.values() for tr in traces)
    for tr in all_traces:
        rms_amp = np.sqrt(np.mean(np.square(tr.data)))
        envelope = np.abs(hilbert(tr.data))
        envelope_mean = np.mean(envelope)
        low_pc, high_pc = np.percentile(envelope, 20), np.percentile(envelope, 80)
        # Assign only after every statistic has been computed.
        tr.stats.rms_amp = rms_amp
        tr.stats.mean_cplx_amp = envelope_mean
        tr.stats.amp_20pc = low_pc
        tr.stats.amp_80pc = high_pc

## Convert RFStream to dict database for convenient iteration

In [None]:
# Sanity check of the container type returned by rf.read_rf().
type(oa_all)

In [None]:
# Nested lookup used throughout: db[station][channel] -> list of traces.
db = rf_to_dict(oa_all)

In [None]:
# import pandas as pd
# sta_codes = list(db.keys())
# sta_lat = [db[sta]['HHR'][0].stats.station_latitude for sta in sta_codes]
# sta_lon = [db[sta]['HHR'][0].stats.station_longitude for sta in sta_codes]
# df_OA_sta = pd.DataFrame.from_dict({'Station': sta_codes, 'Latitude': sta_lat, 'Longitude': sta_lon})
# df_OA_sta.to_csv('OA_stations.txt', index=False)

In [None]:
# # Get coordinates of select stations for north-south lines
# for sta in ['BV21', 'BV28', 'BX20', 'BX28', 'BY20', 'BY28', 'CB20', 'CB28', 'CD21', 'CD28']:
#     lat = db[sta]['HHR'][0].stats.station_latitude
#     lon = db[sta]['HHR'][0].stats.station_longitude
#     print("{}: ({} {})".format(sta, lat, lon))

In [None]:
# # Get coordinates of select stations for west-east lines
# for sta in ['BU22', 'CI22', 'BS26', 'CE26', 'BS28', 'CF28']:
#     lat = db[sta]['HHR'][0].stats.station_latitude
#     lon = db[sta]['HHR'][0].stats.station_longitude
#     print("{}: ({} {})".format(sta, lat, lon))

In [None]:
# # Get coordinates of select stations for Andrew Clark
# for i, sta_pair in enumerate([('BV21', 'CC28'), ('BX20', 'BX28'), ('BY20', 'BY28')]):
#     if i == 0:
#         # NE-->SW line
#         lat0 = db[sta_pair[0]]['HHR'][0].stats.station_latitude + 0.1
#         lon0 = db[sta_pair[0]]['HHR'][0].stats.station_longitude - 0.1
#         lat1 = db[sta_pair[1]]['HHR'][0].stats.station_latitude - 0.1
#         lon1 = db[sta_pair[1]]['HHR'][0].stats.station_longitude + 0.1
#     else:
#         # N-->S line
#         lat0 = db[sta_pair[0]]['HHR'][0].stats.station_latitude + 0.1
#         lon0 = db[sta_pair[0]]['HHR'][0].stats.station_longitude
#         lat1 = db[sta_pair[1]]['HHR'][0].stats.station_latitude - 0.1
#         lon1 = db[sta_pair[1]]['HHR'][0].stats.station_longitude
#     print("{}-{}: --start-latlon {} {} --end-latlon {} {}".format(sta_pair[0], sta_pair[1], lat0, lon0, lat1, lon1))

## Select test station and channel

In [None]:
# Station examined in detail below. To inspect another station (e.g. 'BS27' or 'BZ20'), change
# test_station rather than oa_test directly, so the quality-label file name derived from
# test_station stays consistent with the traces being labelled.
test_station = 'BT23'
oa_test = db[test_station]

In [None]:
# Channel analysed in the rest of the notebook (Q component; the source file name indicates LQT rotation).
channel = 'HHQ'

In [None]:
# Number of traces available for the test channel.
len(oa_test[channel])

In [None]:
# Number of traces in the test channel containing at least one NaN sample.
np.sum([np.any(np.isnan(tr.data)) for tr in oa_test[channel]])

## Add amplitude statistics to help discriminate trace quality

In [None]:
# Adds rms_amp, mean_cplx_amp, amp_20pc and amp_80pc to every trace's stats (in place).
add_new_metadata(oa_test)

## Examine available metadata in each trace

In [None]:
# Container type of the per-channel trace collection.
type(oa_test[channel])

In [None]:
# Type of an individual trace.
type(oa_test[channel][0])

In [None]:
# Full metadata of the first trace, including the quality metrics extracted below.
oa_test[channel][0].stats

## Display ranges of metadata and quality metrics

In [None]:
def get_metadata_series(traces, field):
    """Return ``tr.stats.get(field)`` for each trace, as a list in trace order.

    Missing fields presumably come back as None (dict-style ``get``) - see the rf_group handling below.
    """
    return list(map(lambda tr: tr.stats.get(field), traces))

In [None]:
# Extract metadata and quality data on all traces for the target channel
# (one list per field, in trace order).
snr = get_metadata_series(oa_test[channel], 'snr')
entropy = get_metadata_series(oa_test[channel], 'entropy')
coherence = get_metadata_series(oa_test[channel], 'max_coherence')
distance = get_metadata_series(oa_test[channel], 'distance')
inclination = get_metadata_series(oa_test[channel], 'inclination')
magnitude = get_metadata_series(oa_test[channel], 'event_magnitude')
depth = get_metadata_series(oa_test[channel], 'event_depth')
amax = get_metadata_series(oa_test[channel], 'amax')
amp_20pc = get_metadata_series(oa_test[channel], 'amp_20pc')
amp_80pc = get_metadata_series(oa_test[channel], 'amp_80pc')
mean_cplx_amp = get_metadata_series(oa_test[channel], 'mean_cplx_amp')
rf_group = get_metadata_series(oa_test[channel], 'rf_group')
rms_amp = get_metadata_series(oa_test[channel], 'rms_amp')
# Replace missing (None) group IDs with the integer -1
rf_group = [g if g is not None else -1 for g in rf_group]

In [None]:
# (values, label) pairs for the distribution plots; the 12 entries fill the 4x3 panel grid below.
dist_array = [(snr, "SNR"), (entropy, "Entropy"), (coherence, "Coherence"), (distance, "Distance"),
              (inclination, "Inclination"), (magnitude, "Magnitude"), (amax, "Max amplitude"), (amp_20pc, "Amplitude 20th perc."),
              (amp_80pc, "Amplitude 80th perc."), (mean_cplx_amp, "Mean amplitude"), (rms_amp, "RMS amplitude"), (rf_group, "Group ID")]

In [None]:
# Distribution of each metadata/quality metric for the test channel, one panel per metric.
# sns.histplot (histogram + KDE, density-normalised) replaces the deprecated sns.distplot.
fig, axes = plt.subplots(4, 3, figsize=(20, 15))
for ax, (data, name) in zip(axes.flat, dist_array):
    sns.histplot(data, bins=20, kde=True, stat='density', ax=ax)
    ax.set_title(name + " distribution", y=0.88, fontweight='bold')
plt.show()

In [None]:
# Examine co-plots to look for discriminating variables.
# Continuous metrics are kept as-is. Magnitude, distance, depth and inclination become two-level
# categories (usable as hue): every row starts in the upper bin and rows below the threshold are relabelled.
# Quality starts as "unknown" until labels are loaded or entered below.
df = pd.DataFrame.from_dict({"SNR": snr, "Entropy": entropy, "Coherence": coherence, "Max_amp": amax,
                             "Amp_20pc": amp_20pc, "Amp_80pc": amp_80pc, "RMS_amp": rms_amp, "Mean_amp": mean_cplx_amp,
                             "Magnitude": ">=6", "Distance": ">=60", "Depth": ">=80km",
                             "Inclination": ">=20", "Group_id": rf_group,
                             "Quality": "unknown"})
df.loc[(np.array(magnitude) < 6.0), "Magnitude"] = "<6"
df.loc[(np.array(distance) < 60.0), "Distance"] = "<60"
df.loc[(np.array(inclination) < 20.0), "Inclination"] = "<20"
df.loc[(np.array(depth) < 80.0), "Depth"] = "<80km"

In [None]:
# Restore previously saved quality labels, if any. The file is written below with
# df['Quality'].to_csv(), so its first column is the row index: read it as the index and take only
# the 'Quality' column (assigning the whole two-column frame to a single column fails). Rows absent
# from the file stay 'unknown' so they are offered for labelling.
qual_file = test_station + "_quality.csv"
if os.path.isfile(qual_file):
    loaded_quality = pd.read_csv(qual_file, index_col=0)
    df['Quality'] = loaded_quality['Quality'].reindex(df.index).fillna('unknown')

### Manually label the quality of the traces using interactive input prompts

In [None]:
print("Quality guide:")
print("'a' = low signal before onset, higher signal after onset with some multiples visible")
print("'b' = signal similar before and after onset, cannot make out multiples with much confidence")
print("Create labels by entering 10 character string of 'a's and 'b's according to quality, ordered from bottom to top trace.")
# Create labels for quality. Note that rf plots are numbered from the bottom up, whereas the Pandas table is displayed ordered from the top down.
# Type 'quit' to stop; labels entered so far are saved.
quality_col = df.columns.get_loc('Quality')
quality_updated = False
for i in range(0, len(df), 10):
    existing_qual = df['Quality'].iloc[i:i+10].values
    if 'unknown' not in existing_qual:
        continue
    rf_slice = rf.RFStream(oa_test[channel][i:i+10])
    plot_rf_stack(rf_slice, trace_height=0.4)
    plt.show()
    user_quit = False
    while True:
        get_labels = input("Enter labels: ")
        if get_labels.lower() == 'quit':
            user_quit = True
            break
        get_labels = get_labels.lower()
        if len(get_labels) != len(rf_slice):
            print("Wrong number of labels, try again!")
        elif not set(get_labels) <= {'a', 'b'}:
            print("Labels must be 'a' or 'b', try again!")
        else:
            break
    if user_quit:
        break
    # Assign through df itself: chained assignment (df['Quality'].iloc[k] = ...) may write to a
    # temporary copy and be lost under pandas copy-on-write.
    df.iloc[i:i+len(get_labels), quality_col] = list(get_labels)
    quality_updated = True
    display(df.iloc[i:i+10])

if quality_updated:
    df['Quality'].to_csv(qual_file)

In [None]:
# Assign quality category to trace metadata
for i, tr in enumerate(oa_test[channel]):
    tr.stats.quality = df['Quality'].iloc[i]

### Plot labelled data to find metrics to discriminate trace quality

In [None]:
@interact_manual
def metrics_pairplot(hue_by=['Quality', 'Magnitude', 'Distance', 'Depth', 'Inclination', 'Group_id']):
    """Pairwise scatter plot of the quality metrics in the global ``df``, coloured by ``hue_by``.

    The list default makes ``interact_manual`` offer a dropdown of column names. The function is
    then called with the single selected name whenever the user runs it.
    """
    hue_order = None
    if hue_by == 'Quality':
        # Test the column *values*: `'unknown' in df['Quality']` checks the Series index instead, so
        # unlabelled traces would be left out of hue_order (and may then be omitted from the plot).
        has_unknown = (df['Quality'] == 'unknown').any()
        hue_order = ['unknown', 'b', 'a'] if has_unknown else ['b', 'a']
    sns.pairplot(df, hue=hue_by, hue_order=hue_order,
                 vars=["SNR", "Entropy", "Coherence", "Max_amp", "Amp_20pc", "Amp_80pc", "RMS_amp", "Mean_amp"])
    plt.suptitle("Pairwise quality metrics scatter plot", y=1.01, fontsize=20)

## Evaluate how effectively selected metadata metrics isolate the Quality A events

In [None]:
# Compare two simple metadata-threshold filters against the manual Quality A labels.
num_total = len(oa_test[channel])


def _sorted_by_back_azimuth(traces):
    """Return an RFStream of ``traces`` ordered by back azimuth."""
    return rf.RFStream(sorted(traces, key=lambda v: v.stats.back_azimuth))


def _filter_performance(predicted_ids, reference_ids, total):
    """Score predicted-positive event IDs against the reference (Quality A) event IDs.

    Returns ``(num_true_positive, positive predictive value %, false omission rate %)``.
    A rate is NaN when its denominator is zero, instead of raising ZeroDivisionError.
    """
    predicted = set(predicted_ids)
    reference = set(reference_ids)
    num_true_positive = sum(1 for eid in predicted_ids if eid in reference)
    num_false_negative = sum(1 for eid in reference_ids if eid not in predicted)
    num_predicted_positive = len(predicted_ids)
    num_predicted_negative = total - num_predicted_positive
    ppv = 100.0*num_true_positive/num_predicted_positive if num_predicted_positive else float('nan')
    false_omission = 100.0*num_false_negative/num_predicted_negative if num_predicted_negative else float('nan')
    return num_true_positive, ppv, false_omission


rf_stream_A = _sorted_by_back_azimuth([tr for tr in oa_test[channel] if tr.stats.quality == 'a'])
print("Quality A: {} events".format(len(rf_stream_A)))
quality_A_ids = [tr.stats.event_id for tr in rf_stream_A]
quality_A_id_set = set(quality_A_ids)
not_quality_A_ids = [tr.stats.event_id for tr in oa_test[channel] if tr.stats.event_id not in quality_A_id_set]

# Filter 1: SNR, entropy and coherence thresholds.
rf_stream_stats_filtered = _sorted_by_back_azimuth(
    [tr for tr in oa_test[channel] if tr.stats.snr >= 1.5 and tr.stats.entropy >= 3.0 and tr.stats.max_coherence >= 0.15])
num_filtered = len(rf_stream_stats_filtered)
print("Stats filtered: {} events".format(num_filtered))
stats_filtered_ids = [tr.stats.event_id for tr in rf_stream_stats_filtered]
num_qual_A, ppv, false_omission = _filter_performance(stats_filtered_ids, quality_A_ids, num_total)
# Determine how many of the events in stats_filtered_ids are Quality A events
print("{}/{} correct filtered events (snr, entropy, coherence) (Positive predictive value = {:.2f}%, False omission rate = {:.2f}%)"
      .format(num_qual_A, num_filtered, ppv, false_omission))

# Filter 2: amplitude thresholds.
rf_stream_stats2_filtered = _sorted_by_back_azimuth(
    [tr for tr in oa_test[channel] if tr.stats.amax <= 0.3 and tr.stats.amp_20pc <= 0.03 and tr.stats.amp_80pc <= 0.1])
num2_filtered = len(rf_stream_stats2_filtered)
print("Stats2 filtered: {} events".format(num2_filtered))
stats2_filtered_ids = [tr.stats.event_id for tr in rf_stream_stats2_filtered]
num2_qual_A, ppv2, false_omission2 = _filter_performance(stats2_filtered_ids, quality_A_ids, num_total)
print("{}/{} filtered events (Max. amp, 20%, 80%) are quality A events (Positive predictive value = {:.2f}%, False omission rate = {:.2f}%)"
      .format(num2_qual_A, num2_filtered, ppv2, false_omission2))


## Plot RFs for traces filtered by various quality metrics

### Quality A

In [None]:
# Receiver functions of all traces labelled Quality A, ordered by back azimuth.
plot_rf_stack(rf_stream_A)

### Quality B

In [None]:
# Traces labelled Quality B, ordered by back azimuth; only the first 100 are plotted.
rf_data = sorted((tr for tr in oa_test[channel] if tr.stats.quality == 'b'),
                 key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream[0:100])

### High SNR

In [None]:
# Traces with high signal-to-noise ratio (SNR >= 3), ordered by back azimuth.
# Modules are not reloaded here: reloading mid-notebook makes results depend on execution history;
# restart the kernel instead after changing library code.
rf_data = [tr for tr in oa_test[channel] if tr.stats.snr >= 3.0]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# plot_rf_stack(rf_stream, save_file="sample_rf_aligners.png")
plot_rf_stack(rf_stream)

### Low SNR

In [None]:
# Traces with low signal-to-noise ratio (SNR <= 0.8), ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.snr <= 0.8]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream)

### Low entropy

In [None]:
# Traces with low entropy (<= 3.0), ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.entropy <= 3.0]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream)

### High entropy

In [None]:
# Traces with high entropy (>= 4.2), ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.entropy >= 4.2]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream)

### High coherence

In [None]:
# Traces with high maximum coherence (>= 0.3), ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.max_coherence >= 0.3]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream)

### Low coherence

In [None]:
# Traces with very low maximum coherence (<= 0.02), ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.max_coherence <= 0.02]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream)

### High magnitude

In [None]:
# Events with magnitude >= 5.5, ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.event_magnitude >= 5.5]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream[0:100])

### Low magnitude

In [None]:
# Events with magnitude < 5.5, ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.event_magnitude < 5.5]
rf_data.sort(key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
plot_rf_stack(rf_stream[0:100])

***

## Plot overlay of all Quality A traces in the test channel (no further filtering)

In [None]:
# Quality A traces in the channel -> list-of-traces layout used by the helper functions above.
oa_quality = {channel: list(rf_stream_A)}

In [None]:
# Overlay all Quality A traces with their mean; trace_mean holds the per-channel mean signals.
num_traces = len(oa_quality[channel])
trace_mean = plot_station_rf_overlays(oa_quality, '(all {} traces)'.format(num_traces))

## Split traces into groups and plot each group

In [None]:
# Bucket the Quality A traces by their 'rf_group' metadata: group id -> {channel: [traces]}.
# Traces without a group are skipped.
group_dict = {}
for tr in oa_quality[channel]:
    grp = tr.stats.get('rf_group')
    if grp is None:
        continue
    group_dict.setdefault(grp, {}).setdefault(channel, []).append(tr)

groups = group_dict.keys()
print("Found {} groups: {}".format(len(groups), groups))

In [None]:
# One overlay figure per rf_group. group_mean is reassigned each iteration, so only the last
# group's means remain afterwards.
for grp_id, group in group_dict.items():
    num_traces = len(group[channel])
    title = '(group {}, {} traces)'.format(grp_id, num_traces)
    group_mean = plot_station_rf_overlays(group, title)

## Plot only traces with similarity to the mean

In [None]:
# Keep traces whose projection score onto their channel's normalised summed signal is >= 0.05;
# corrs holds the score of every evaluated trace.
oa_quality_filt, corrs = filter_station_to_mean_signal(oa_quality, min_correlation=0.05)

In [None]:
# Histogram of the similarity scores, to judge where the min_correlation threshold falls.
plt.hist(corrs, bins=50)
plt.show()

In [None]:
# Overlay of the traces that passed the similarity filter.
num_traces = len(oa_quality_filt[channel])
test_filt_mean = plot_station_rf_overlays(oa_quality_filt, '({} traces similar to mean)'.format(num_traces))

## Demonstrate the effectiveness of phase-weighting the traces

In [None]:
from seismic.receiver_fn.rf_util import phase_weights

In [None]:
# Per-sample phase weights across the filtered traces (one value per sample, plotted against trace
# time below). NOTE(review): range presumably [0, 1], higher where phases agree - confirm in rf_util.
pw = phase_weights(oa_quality_filt[channel])

In [None]:
# Plot the phase weights against the onset-relative time axis of the first filtered trace.
s0 = oa_quality_filt[channel][0]
time_offset = s0.stats.onset - s0.stats.starttime
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, pw)
plt.title('Phase weightings')
plt.grid()
plt.show()

In [None]:
# Demonstrate effect of phase weighting to suppress areas where phases tend to be random.
pw_exponent = 2
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, s0.data, linewidth=2)
plt.plot(s0.times() - time_offset, s0.data*pw**pw_exponent, '--', linewidth=2)
plt.legend(['Original', 'Phase weighted'])
plt.title('Phase weighting applied to a single trace')
plt.grid()
plt.show()

In [None]:
# Apply phase weighting to data for H-k stacking
# NOTE: This will overwrite the original filtered data
for tr in oa_quality_filt[channel]:
    tr.data = tr.data*pw**pw_exponent

num_traces = len(oa_quality_filt[channel])
test_filt_mean = plot_station_rf_overlays(oa_quality_filt, '({} traces similar to mean, phase weighted)'.format(num_traces))

# Plot HK stacks

In [None]:
# Traces used as input to the H-k stacking below (Quality A, similarity-filtered, phase-weighted).
hk_src_data = oa_quality_filt

In [None]:
# Plot stack
weighting = (0.35, 0.35, 0.3)

for cha in [channel]:
    k_grid, h_grid, hk_stack = compute_hk_stack(hk_src_data, cha, root_order=2)

    hk_stack_sum = compute_weighted_stack(hk_stack, weighting)
    
    sta = hk_src_data[cha][0].stats.station

    num = len(hk_src_data[cha])
    save_file = None
    plot_hk_stack(k_grid, h_grid, hk_stack[0], title=sta + '.{} Ps'.format(cha), num=num)
    plot_hk_stack(k_grid, h_grid, hk_stack[1], title=sta + '.{} PpPs'.format(cha), num=num)
    plot_hk_stack(k_grid, h_grid, hk_stack[2], title=sta + '.{} PpSs + PsPs'.format(cha), num=num)
    plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha) + ' (no filtering)', num=num, save_file=save_file)

***

# Loop over all OA stations and plot H-k stacks

In [None]:
# cha = channel
# pbar = tqdm(total=len(db))
# show = False
# weighting = (0.5, 0.4, 0.1)
# for sta, db_sta in db.items():
#     pbar.set_description(sta)
#     pbar.update()
#     k_grid, h_grid, hk_stack = compute_hk_stack(db_sta, cha, root_order=2)
#     hk_stack_sum = compute_weighted_stack(hk_stack, weighting)
#     sta = db_sta[cha][0].stats.station
#     save_file = sta + "_{}_hk_stack.png".format(cha)
#     num = len(db_sta[cha])
#     plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha), save_file=save_file, show=show, num=num)
# pbar.close()