In [None]:
import os
# from collections import defaultdict
# import time

import numpy as np
import rf
import rf.imaging
import matplotlib.pyplot as plt
import scipy
from scipy import signal
from scipy.signal import hilbert
from scipy.stats import moment
# from scipy.interpolate import interp1d
import obspy
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [None]:
# Bring in interactive widgets capability. See https://towardsdatascience.com/interactive-controls-for-jupyter-notebooks-f5c94829aee6
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

In [None]:
import seismic.receiver_fn.rf_util as rf_util
import seismic.receiver_fn.rf_plot_utils as rf_plot_utils
import seismic.receiver_fn.rf_stacking as rf_stacking

## Read source file

In [None]:
# Build the path from components so it resolves on POSIX as well as on Windows
# (a literal "..\DATA\..." is a single file name on non-Windows systems).
src_file = os.path.join("..", "DATA",
                        "OA_event_waveforms_for_rf_20170911T000036-20181128T230620_LQT_td_rev3_qual.h5")

In [None]:
# Read the source file through the project helper rf_util.read_h5_rf.
oa_all = rf_util.read_h5_rf(src_file)

In [None]:
# Inspect the type of the loaded container.
type(oa_all)

## Convert RFStream to dict database for convenient iteration and addressing

In [None]:
# db is indexed by station code and then by channel code in the cells below.
db = rf_util.rf_to_dict(oa_all)

## Select test station and channel

In [None]:
# Station analysed in the rest of the notebook; alternative stations are kept commented out.
test_station = 'BT23'
# test_station = 'BS27'
# test_station = 'BZ20'
# Per-station entry of the database; indexed by channel code below.
oa_test = db[test_station]

In [None]:
# Channel code used to index oa_test throughout the notebook.
channel = 'HHQ'

In [None]:
# Number of traces available for the selected station and channel.
len(oa_test[channel])

In [None]:
# Count the traces whose data contain at least one NaN. RF quality filtering prior to this
# SHOULD have removed any such traces, so the expected count is zero.
nan_flags = [np.isnan(tr.data).any() for tr in oa_test[channel]]
np.sum(nan_flags)

## Add additional statistics for prediction of trace quality

In [None]:
# The return value is not used here; presumably this updates the trace stats of oa_test
# in place - confirm against rf_util.
rf_util.compute_extra_rf_stats(oa_test)

## Examine available metadata in each trace

In [None]:
# Type of the per-channel trace container.
type(oa_test[channel])

In [None]:
# Type of an individual trace (the first one in the channel).
type(oa_test[channel][0])

In [None]:
# Metadata attached to the first trace; the fields extracted below are read from this stats object.
oa_test[channel][0].stats

## Display ranges of metadata and quality metrics

In [None]:
def get_metadata_series(traces, field):
    """Collect one metadata value per trace.

    :param traces: Iterable of traces, each exposing a ``stats`` mapping with a ``get`` method.
    :param field: Name of the metadata field to look up in each trace's ``stats``.
    :return: List with ``stats.get(field)`` for every trace, in input order. The value is
        whatever ``get`` returns for a missing field (``None`` for a plain dict).
    """
    return [trace.stats.get(field) for trace in traces]

In [None]:
# Pull one list per metadata / quality field out of the traces of the target channel.
target_traces = oa_test[channel]
snr = get_metadata_series(target_traces, 'snr')
entropy = get_metadata_series(target_traces, 'entropy')
coherence = get_metadata_series(target_traces, 'max_coherence')
distance = get_metadata_series(target_traces, 'distance')
inclination = get_metadata_series(target_traces, 'inclination')
magnitude = get_metadata_series(target_traces, 'event_magnitude')
depth = get_metadata_series(target_traces, 'event_depth')
amax = get_metadata_series(target_traces, 'amax')
amp_20pc = get_metadata_series(target_traces, 'amp_20pc')
amp_80pc = get_metadata_series(target_traces, 'amp_80pc')
mean_cplx_amp = get_metadata_series(target_traces, 'mean_cplx_amp')
rms_amp = get_metadata_series(target_traces, 'rms_amp')
# Traces without a group ID get the placeholder group '-1'.
rf_group = [-1 if g is None else g for g in get_metadata_series(target_traces, 'rf_group')]

In [None]:
# (values, display name) pairs for the distribution plots, in plotting order.
# Event depth is extracted above but is not part of this list.
dist_array = [
    (snr, "SNR"),
    (entropy, "Entropy"),
    (coherence, "Coherence"),
    (distance, "Distance"),
    (inclination, "Inclination"),
    (magnitude, "Magnitude"),
    (amax, "Max amplitude"),
    (amp_20pc, "Amplitude 20th perc."),
    (amp_80pc, "Amplitude 80th perc."),
    (mean_cplx_amp, "Mean amplitude"),
    (rms_amp, "RMS amplitude"),
    (rf_group, "Group ID"),
]

In [None]:
# One distribution panel per entry of dist_array, laid out on a 4 x 3 grid.
fig, axes = plt.subplots(4, 3, figsize=(20, 15))
for ax, (data, name) in zip(axes.flat, dist_array):
    sns.distplot(data, bins=20, ax=ax)
    # Title is placed inside the top of the panel (y < 1) so it does not collide with the row above.
    ax.set_title(name + " distribution", y=0.88, fontweight='bold')
plt.show()

In [None]:
# Examine co-plots to look for discriminating variables.
# Categorical columns start at their ">= threshold" label (scalars are broadcast to every row);
# rows below the threshold are relabelled afterwards. Quality starts as "unknown".
df = pd.DataFrame.from_dict({
    "SNR": snr,
    "Entropy": entropy,
    "Coherence": coherence,
    "Max_amp": amax,
    "Amp_20pc": amp_20pc,
    "Amp_80pc": amp_80pc,
    "RMS_amp": rms_amp,
    "Mean_amp": mean_cplx_amp,
    "Magnitude": ">=6",
    "Distance": ">=60",
    "Depth": ">=80km",
    "Inclination": ">=20",
    "Group_id": rf_group,
    "Quality": "unknown",
})
threshold_labels = [
    ("Magnitude", magnitude, 6.0, "<6"),
    ("Distance", distance, 60.0, "<60"),
    ("Inclination", inclination, 20.0, "<20"),
    ("Depth", depth, 80.0, "<80km"),
]
for column, values, threshold, low_label in threshold_labels:
    df.loc[np.array(values) < threshold, column] = low_label

In [None]:
# Per-station file holding previously entered quality labels (row index, label; no header row).
qual_file = test_station + "_quality.csv"
if os.path.isfile(qual_file):
    loaded_quality = pd.read_csv(qual_file, index_col=0, header=None)
    # Assign the single label column as a Series (not the whole DataFrame) so that rows are
    # aligned on the index; rows that are absent from the file keep the 'unknown' label.
    df['Quality'] = loaded_quality.iloc[:, 0].reindex(df.index).fillna('unknown')

### Use interactive widget to manually label the quality of the traces

In [None]:
print("Quality guide:")
print("'a' = low signal before onset, higher signal after onset with some multiples visible")
print("'b' = signal similar before and after onset, cannot make out multiples with much confidence")
print("Create labels by entering 10 character string of 'a's and 'b's according to quality, ordered from bottom to top trace.")
# Create labels for quality. Note that rf plots are numbered from the bottom up, whereas the Pandas table is displayed ordered from the top down.
# Traces are presented in batches of 10; entering 'quit' stops labelling and keeps what was entered so far.
quality_col = df.columns.get_loc('Quality')
quality_updated = False
for i in range(0, len(df), 10):
    existing_qual = df['Quality'].iloc[i:i+10].values
    # Skip batches that are already fully labelled (e.g. loaded from the quality file).
    if 'unknown' not in existing_qual:
        continue
    rf_slice = rf.RFStream(oa_test[channel][i:i+10])
    # NOTE(review): plot_rf_stack was called unqualified but is not defined in this notebook;
    # assumed to live in rf_plot_utils - confirm.
    rf_plot_utils.plot_rf_stack(rf_slice, trace_height=0.4)
    plt.show()
    get_labels = ''
    user_quit = False  # do not shadow the builtin quit()
    while len(get_labels) != len(rf_slice):
        get_labels = input("Enter labels: ")
        if get_labels.lower() == 'quit':
            user_quit = True
            break
        if len(get_labels) != len(rf_slice):
            print("Wrong number of labels, try again!")
    if user_quit:
        break
    # Write through a single positional indexer on the frame itself; chained
    # df['Quality'].iloc[...] = ... may update a temporary copy instead of df.
    df.iloc[i:i+len(get_labels), quality_col] = list(get_labels)
    quality_updated = True
    display(df.iloc[i:i+10])

if quality_updated:
    # header=False keeps the file in the headerless layout that pd.read_csv(..., header=None) expects,
    # independent of the pandas version's default for Series.to_csv.
    df['Quality'].to_csv(qual_file, header=False)
else:
    display(df)

In [None]:
# Copy each row's quality label onto the matching trace (rows and traces share the same order).
for tr, quality_label in zip(oa_test[channel], df['Quality']):
    tr.stats.quality = quality_label

### Plot labelled data to find metrics to discriminate trace quality

In [None]:
# Numeric df columns used as pairplot variables and as classifier features below.
stats_metrics = ["SNR", "Entropy", "Coherence", "Max_amp", "Amp_20pc", "Amp_80pc", "RMS_amp", "Mean_amp"]

In [None]:
@interact_manual
def metrics_pairplot(hue_by=['Quality', 'Magnitude', 'Distance', 'Depth', 'Inclination', 'Group_id']):
    """Draw a pairwise scatter plot of the quality metrics, coloured by a categorical column.

    :param hue_by: Name of the df column used for colouring; the list default supplies the
        choices offered by the interact_manual widget.
    """
    hue_order = None
    if hue_by == 'Quality':
        # `in` on a Series tests the index labels, so test membership against the values instead.
        hue_order = ['unknown', 'b', 'a'] if 'unknown' in df['Quality'].values else ['b', 'a']
    sns.pairplot(df, hue=hue_by, hue_order=hue_order, vars=stats_metrics)
    plt.suptitle("Pairwise quality metrics scatter plot", y=1.01, fontsize=20)

## Look at how effective selected metadata metrics are at filtering to the Quality A set of events

In [None]:
def _percent(count, total):
    """Return count as a percentage of total, or NaN when total is zero."""
    return 100.0*count/total if total else float('nan')


num_total = len(oa_test[channel])

# Reference set: traces manually labelled as quality 'a', ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.quality == 'a']
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream_A = rf.RFStream(rf_data)
print("Quality A: {} events".format(len(rf_stream_A)))
quality_A_ids = [tr.stats.event_id for tr in rf_stream_A]
not_quality_A_ids = [tr.stats.event_id for tr in oa_test[channel] if tr.stats.event_id not in quality_A_ids]

# Filter on SNR, entropy and coherence thresholds.
rf_data = [tr for tr in oa_test[channel] if tr.stats.snr >= 1.5 and tr.stats.entropy >= 3.0 and tr.stats.max_coherence >= 0.15]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream_stats_filtered = rf.RFStream(rf_data)
num_filtered = len(rf_stream_stats_filtered)
print("Stats filtered: {} events".format(num_filtered))
stats_filtered_ids = [tr.stats.event_id for tr in rf_stream_stats_filtered]
true_positives = [ev_id for ev_id in stats_filtered_ids if ev_id in quality_A_ids]
false_negatives = [ev_id for ev_id in quality_A_ids if ev_id not in stats_filtered_ids]
num_true_positive = len(true_positives)
num_false_negative = len(false_negatives)
num_predicted_positive = len(stats_filtered_ids)
num_predicted_negative = num_total - num_predicted_positive

# Determine how many of the events in stats_filtered_ids are Quality A events.
# _percent guards against a filter that selects no traces or all traces (zero denominator).
print("{}/{} correct filtered events (snr, entropy, coherence) (Positive predictive value = {:.2f}%, False omission rate = {:.2f}%)"
      .format(num_true_positive, num_filtered, _percent(num_true_positive, num_predicted_positive),
              _percent(num_false_negative, num_predicted_negative)))

# Repeat using amplitude metrics
rf_data = [tr for tr in oa_test[channel] if tr.stats.amax <= 0.3 and tr.stats.amp_20pc <= 0.03 and tr.stats.amp_80pc <= 0.1]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream_stats2_filtered = rf.RFStream(rf_data)
num2_filtered = len(rf_stream_stats2_filtered)
print("Stats2 filtered: {} events".format(num2_filtered))
stats2_filtered_ids = [tr.stats.event_id for tr in rf_stream_stats2_filtered]
true_positives = [ev_id for ev_id in stats2_filtered_ids if ev_id in quality_A_ids]
false_negatives = [ev_id for ev_id in quality_A_ids if ev_id not in stats2_filtered_ids]
num_true_positive = len(true_positives)
num_false_negative = len(false_negatives)
num_predicted_positive = len(stats2_filtered_ids)
num_predicted_negative = num_total - num_predicted_positive

print("{}/{} filtered events (Max. amp, 20%, 80%) are quality A events (Positive predictive value = {:.2f}%, False omission rate = {:.2f}%)"
      .format(num_true_positive, num2_filtered, _percent(num_true_positive, num_predicted_positive),
              _percent(num_false_negative, num_predicted_negative)))

# The performance stats printed by this cell show what a human achieves trying to tune data selection criteria manually.

## See how well a neural network classifier works in comparison

### Use simple stats for feature vector

In [None]:
from sklearn.neural_network import MLPClassifier
# from sklearn import preprocessing
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import confusion_matrix
from sklearn.svm import LinearSVC, SVC
from sklearn import tree

In [None]:
# Feature matrix: one row per trace, one column per entry of stats_metrics.
# Missing values are zero-filled on the DataFrame side, which avoids writing into the array
# returned by .values (it can be read-only) and also handles None entries.
X = df.loc[:, stats_metrics].fillna(0).values
# Target labels: the manually assigned quality class of each trace.
y = df['Quality'].values

In [None]:
# Hold out 40% of the samples for testing; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

In [None]:
# scaler = preprocessing.StandardScaler().fit(X_train)
# X_train_transformed = scaler.transform(X_train)

In [None]:
# This perceptron network has been simplified back to the bare bone so that it corresponds to a linear predictor,
# as higher order complexity and non-linear activation functions gave no improvement in accuracy.
# A single hidden unit with identity activation is used; random_state is fixed for repeatability.
clf_simple = MLPClassifier(solver='lbfgs', alpha=1e-4, max_iter=1000, activation='identity',
                           hidden_layer_sizes=(1,), random_state=3772, tol=1e-4)

In [None]:
# 5-fold cross-validation on the training split, used to tune hyperparameters.
scores = cross_val_score(clf_simple, X_train, y_train, cv=5)
print(scores)
# Report the mean score with a spread of two standard deviations.
cv_mean, cv_spread = scores.mean(), scores.std()*2
print("Accuracy: %0.3f (+/- %0.3f)" % (cv_mean, cv_spread))

In [None]:
# With tuned hyperparameters, train on full training set.
clf_simple.fit(X_train, y_train)
# Report the loss attribute of the fitted classifier.
print("Final loss: %0.4f" % clf_simple.loss_)

In [None]:
# Score the trained classifier on the held-out test split.
final_score = clf_simple.score(X_test, y_test)
print("Final accuracy: %0.3f" % final_score)
# We get decent performance with a trivial network (1 neuron) with trivial activation f(x) = x,
# which means that simply a linear combination of feature vector is sufficient to determine
# classification.

In [None]:
# Predict labels for the held-out test split and for every trace (training rows included).
test_prediction = clf_simple.predict(X_test)
full_prediction = clf_simple.predict(X)
# Keep the full-set prediction next to the manual labels in df.
df['Prediction'] = full_prediction
# Show a repeatable random sample of rows as a spot check.
df.sample(20, random_state=3772)

In [None]:
# Unpack the fitted network parameters: A0/b0 come from the first entries of coefs_/intercepts_
# (feature weights into the single hidden unit), A1/b1 from the second entries (hidden unit to output).
A0 = clf_simple.coefs_[0].T[0]
b0 = clf_simple.intercepts_[0][0]
A1 = clf_simple.coefs_[1][0][0]
b1 = clf_simple.intercepts_[1][0]

# Print the feature names followed by the weights and intercepts for inspection.
print(stats_metrics)
print(A0)
print(clf_simple.coefs_[1][0])
print(clf_simple.intercepts_[0], clf_simple.intercepts_[1])

In [None]:
# Compute linear combination of features according to solver weightings, to check how to use these directly for class prediction.
# Should give exact same result as MLPClassifier.
# Hidden unit output is X_test @ A0 + b0; the output unit scales it by A1 and adds b1.
lin_comb_prediction = A1*(np.matmul(X_test, A0) + b0) + b1

In [None]:
# Convert the linear predictor to class labels up front (negative -> 'a', otherwise 'b') rather than
# overwriting a float column through chained .loc assignment, which may act on a temporary copy.
lin_comb_labels = np.where(lin_comb_prediction < 0, 'a', 'b')
df_predictions = pd.DataFrame.from_dict({"Truth": y_test, "MLP prediction": test_prediction, "Lin. predictor": lin_comb_labels})
# The hand-computed linear predictor must reproduce the classifier's own predictions.
assert np.all(df_predictions['Lin. predictor'] == df_predictions['MLP prediction'])
df_predictions.sample(10, random_state=3772)

In [None]:
# Display confusion matrix and verify how to compute accuracy from it.
# Class order is fixed to ['b', 'a'] for both rows and columns.
cm = confusion_matrix(df_predictions['Truth'], df_predictions['MLP prediction'], labels=['b', 'a'])
print(cm)
print(np.sum(cm))
# Accuracy = sum of the diagonal entries / sum of all entries.
print("Accuracy: %0.3f" % (np.sum(np.diag(cm))/float(np.sum(cm))))

In [None]:
# Look at how good is the DBSCAN grouping as an indicator of trace quality.
# Group 0 is treated as the predicted 'a' class and every other group as 'b'. The labels are built
# as a new string Series instead of writing strings into a copy of the numeric Group_id column.
primary_group_mask = (df['Group_id'] == 0)
dbscan_group = pd.Series(np.where(primary_group_mask, 'a', 'b'), index=df.index)
cm_dbscan = confusion_matrix(df['Quality'], dbscan_group, labels=['b', 'a'])
print(cm_dbscan)
print(np.sum(cm_dbscan))
print("Accuracy: %0.3f" % (np.sum(np.diag(cm_dbscan))/float(np.sum(cm_dbscan))))
# Result here indicates DBSCAN grouping is not a strong predictor of subjective trace quality

In [None]:
# Look at how good SNR alone is as an indicator of trace quality.
# SNR >= 1.5 is treated as the predicted 'a' class, anything else (including NaN) as 'b'. The labels
# are built as a new string Series instead of writing strings into a copy of the float SNR column.
high_snr_mask = (df['SNR'] >= 1.5)
snr_series = pd.Series(np.where(high_snr_mask, 'a', 'b'), index=df.index)
cm_snr = confusion_matrix(df['Quality'], snr_series, labels=['b', 'a'])
print(cm_snr)
print(np.sum(cm_snr))
print("Accuracy: %0.3f" % (np.sum(np.diag(cm_snr))/float(np.sum(cm_snr))))
# Result here indicates SNR alone is quite a good indicator of quality

In [None]:
# Copy each row's PREDICTED quality label onto the matching trace (rows and traces share the same order).
for tr, predicted_label in zip(oa_test[channel], df['Prediction']):
    tr.stats.predicted_quality = predicted_label

### Use (-10, 25) trimmed waveform for feature vector

In [None]:
# Wrap the channel's traces in a new RFStream, then cut every trace to the window
# from -10.0 to 25.0 relative to the 'onset' reference time.
all_traces = rf.RFStream([tr for tr in oa_test[channel]])
all_traces = all_traces.slice2(-10.0, 25.0, reftime='onset')

In [None]:
# Feature matrix built from the trimmed waveform samples, one row per trace.
# NOTE(review): assumes all trimmed traces have the same number of samples - confirm.
X = np.array([tr.data for tr in all_traces])
y = df['Quality'].values
print(X.shape)

In [None]:
# # Convert into frequency domain
# _, X = signal.periodogram(X)
# print(X.shape)

In [None]:
# Hold out 40% of the samples for testing; note the seed differs from the stats-based split (1 vs 0).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=1)

In [None]:
# `preprocessing` is not imported anywhere in the notebook (its import line is commented out),
# so bring it into scope here to avoid a NameError.
from sklearn import preprocessing

# Fit the scaler on the training split only, then apply it to the training data.
scaler = preprocessing.StandardScaler().fit(X_train)
X_train_transformed = scaler.transform(X_train)

In [None]:
# MLP for the waveform features: one hidden layer of 20 units with ReLU activation;
# random_state is fixed for repeatability.
clf_rf = MLPClassifier(solver='lbfgs', alpha=1e-4, max_iter=1000, activation='relu',
                       hidden_layer_sizes=(20,), random_state=3772, tol=1e-4)

In [None]:
# 5-fold cross-validation of the waveform classifier on the scaled training split.
scores = cross_val_score(clf_rf, X_train_transformed, y_train, cv=5)
print(scores)
# Report the mean score with a spread of two standard deviations.
cv_mean, cv_spread = scores.mean(), scores.std()*2
print("Accuracy: %0.3f (+/- %0.3f)" % (cv_mean, cv_spread))

In [None]:
# Train on the full (scaled) training split and report the loss attribute of the fitted classifier.
clf_rf.fit(X_train_transformed, y_train)
print("Final loss: %0.4f" % clf_rf.loss_)

In [None]:
# Apply the scaler fitted on the training split to the test split before scoring.
X_test_transformed = scaler.transform(X_test)
final_score_rf = clf_rf.score(X_test_transformed, y_test)
print("Final accuracy: %0.3f" % final_score_rf)

### Try Support Vector Classifier instead to assess performance

In [None]:
# Rebuild the stats-based feature matrix (X was overwritten with waveform data above).
# Missing values are zero-filled on the DataFrame side, which avoids writing into the array
# returned by .values (it can be read-only) and also handles None entries.
X = df.loc[:, stats_metrics].fillna(0).values
y = df['Quality'].values

In [None]:
# Same split parameters (40% test, random_state=0) as used for the stats-based MLP classifier.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

In [None]:
# scaler = preprocessing.StandardScaler().fit(X_train)
# X_train_transformed = scaler.transform(X_train)

In [None]:
# Linear support vector classifier; an RBF-kernel alternative is kept commented out below.
clf_svc = LinearSVC(random_state=3772)
# clf_svc = SVC(C=1.0, gamma='scale', kernel='rbf', random_state=3772)

In [None]:
# 5-fold cross-validation on the training split, used to tune hyperparameters.
scores = cross_val_score(clf_svc, X_train, y_train, cv=5)
print(scores)
# Report the mean score with a spread of two standard deviations.
cv_mean, cv_spread = scores.mean(), scores.std()*2
print("Accuracy: %0.3f (+/- %0.3f)" % (cv_mean, cv_spread))

In [None]:
# With tuned hyperparameters, train on full training set.
clf_svc.fit(X_train, y_train)
# Report a figure that belongs to the SVC itself: clf_simple.loss_ is the MLP's loss, and
# LinearSVC does not expose a loss_ attribute, so the training accuracy is shown instead.
print("Training accuracy: %0.3f" % clf_svc.score(X_train, y_train))

In [None]:
# Score the trained SVC on the held-out test split.
final_score = clf_svc.score(X_test, y_test)
print("Final accuracy: %0.3f" % final_score)
# This result is pretty much exactly the same as the above MLPClassifier with one neuron, i.e. it's just a linear predictor.

### Try classification using decision tree

In [None]:
# Stats-based feature matrix for the decision tree.
# Missing values are zero-filled on the DataFrame side, which avoids writing into the array
# returned by .values (it can be read-only) and also handles None entries.
X = df.loc[:, stats_metrics].fillna(0).values
y = df['Quality'].values

In [None]:
# Same split parameters (40% test, random_state=0) as used for the other stats-based classifiers.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

In [None]:
# Decision tree with default settings; random_state is fixed for repeatability.
clf_tree = tree.DecisionTreeClassifier(random_state=3772)

In [None]:
# Train the tree on the training split.
clf_tree = clf_tree.fit(X_train, y_train)

In [None]:
# 5-fold cross-validation on the training split, used to tune hyperparameters.
scores = cross_val_score(clf_tree, X_train, y_train, cv=5)
print(scores)
# Report the mean score with a spread of two standard deviations.
cv_mean, cv_spread = scores.mean(), scores.std()*2
print("Accuracy: %0.3f (+/- %0.3f)" % (cv_mean, cv_spread))

## Plot RFs for traces filtered by various quality metrics

### Quality A

In [None]:
# Plot the traces manually labelled as quality 'a' (rf_stream_A, ordered by back azimuth).
# NOTE(review): plot_rf_stack was called unqualified but is not defined in this notebook;
# assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream_A)

### Quality B

In [None]:
# Traces manually labelled as quality 'b', ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.quality == 'b']
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream[0:100])

### High SNR

In [None]:
# Traces with SNR >= 3.0, ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.snr >= 3.0]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream)

### Low SNR

In [None]:
# Traces with SNR <= 0.8, ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.snr <= 0.8]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream)

### High entropy

In [None]:
# Traces with entropy >= 4.2, ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.entropy >= 4.2]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream)

### Low entropy

In [None]:
# Traces with entropy <= 3.0, ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.entropy <= 3.0]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream)

### High coherence

In [None]:
# Traces with max_coherence >= 0.3, ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.max_coherence >= 0.3]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream)

### Low coherence

In [None]:
# Traces with max_coherence <= 0.02, ordered by back azimuth.
rf_data = [tr for tr in oa_test[channel] if tr.stats.max_coherence <= 0.02]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream)

### High magnitude

In [None]:
# Traces from events with magnitude >= 5.5, ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.event_magnitude >= 5.5]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream[0:100])

### Low magnitude

In [None]:
# Traces from events with magnitude < 5.5, ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.event_magnitude < 5.5]
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream[0:100])

### Predicted Quality A

In [None]:
# Traces the classifier predicted as quality 'a', ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.predicted_quality == 'a']
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream[0:100])

### Predicted Quality B

In [None]:
# Traces the classifier predicted as quality 'b', ordered by back azimuth; only the first 100 are plotted.
rf_data = [tr for tr in oa_test[channel] if tr.stats.predicted_quality == 'b']
rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
# NOTE(review): plot_rf_stack is not defined in this notebook; assumed to live in rf_plot_utils - confirm.
rf_plot_utils.plot_rf_stack(rf_stream[0:100])

***

## Plot overlay of all traces in test channel (no filtering)

In [None]:
# Single-channel dict holding the traces of rf_stream_A (manually labelled quality 'a').
# NOTE(review): the section heading says "all traces (no filtering)", but only quality 'a'
# traces are used from here on - confirm which is intended.
oa_quality = {channel: [tr for tr in rf_stream_A]}

In [None]:
# Overlay plot of the traces in oa_quality; the trace count is shown in the title.
# NOTE(review): the title reads "all ... traces" although oa_quality holds quality 'a' traces only.
num_traces = len(oa_quality[channel])
trace_mean = rf_plot_utils.plot_station_rf_overlays(oa_quality, '(all {} traces)'.format(num_traces))

## Split traces into groups and plot each group

In [None]:
# Bucket the traces by their 'rf_group' ID. Each bucket is a single-channel dict
# ({channel: [traces]}); traces without a group ID are left out.
group_dict = {}
for tr in oa_quality[channel]:
    grp = tr.stats.get('rf_group')
    if grp is None:
        continue
    group_dict.setdefault(grp, {channel: []})[channel].append(tr)

groups = group_dict.keys()
print("Found {} groups: {}".format(len(groups), groups))

In [None]:
# One overlay plot per group, titled with the group ID and the number of traces in it.
for grp_id, group in group_dict.items():
    num_traces = len(group[channel])
    title = '(group {}, {} traces)'.format(grp_id, num_traces)
    # The return value of the last iteration is kept in group_mean but not used afterwards here.
    group_mean = rf_plot_utils.plot_station_rf_overlays(group, title)

## Plot only traces with similarity to the mean

In [None]:
# NOTE(review): filter_station_to_mean_signal was called unqualified but is not defined in this
# notebook; assumed to live in rf_util - confirm.
oa_quality_filt, corrs = rf_util.filter_station_to_mean_signal(oa_quality, min_correlation=0.05)

In [None]:
# Histogram of the correlation values returned by the filtering step.
plt.hist(corrs, bins=50)
plt.show()

In [None]:
# Overlay plot of the traces that passed the filter; the trace count is shown in the title.
num_traces = len(oa_quality_filt[channel])
test_filt_mean = rf_plot_utils.plot_station_rf_overlays(oa_quality_filt, '({} traces similar to mean)'.format(num_traces))

## Demonstrate the effectiveness of phase-weighting the traces

In [None]:
from seismic.receiver_fn.rf_util import phase_weights

In [None]:
# Phase weights computed from the filtered traces; plotted against the time axis of a trace below.
pw = phase_weights(oa_quality_filt[channel])

In [None]:
# Use the first filtered trace as the time reference.
s0 = oa_quality_filt[channel][0]
# Offset between trace start and onset, so the x axis is relative to the onset.
time_offset = s0.stats.onset - s0.stats.starttime
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, pw)
plt.title('Phase weightings')
plt.grid()
plt.show()

In [None]:
# Demonstrate effect of phase weighting to suppress areas where phases tend to be random.
# The weights are raised to pw_exponent before being multiplied into the trace.
pw_exponent = 2
plt.figure(figsize=(16,9))
# Solid line: original trace; dashed line: the same trace multiplied by pw**pw_exponent.
plt.plot(s0.times() - time_offset, s0.data, linewidth=2)
plt.plot(s0.times() - time_offset, s0.data*pw**pw_exponent, '--', linewidth=2)
plt.legend(['Original', 'Phase weighted'])
plt.title('Phase weighting applied to a single trace')
plt.grid()
plt.show()

In [None]:
# # Apply phase weighting to data for H-k stacking
# # NOTE: This will overwrite the original filtered data
# for tr in oa_quality_filt[channel]:
#     tr.data = tr.data*pw**pw_exponent

# num_traces = len(oa_quality_filt[channel])
# test_filt_mean = rf_plot_utils.plot_station_rf_overlays(oa_quality_filt, '({} traces similar to mean, phase weighted)'.format(num_traces))

# Plot HK stacks

In [None]:
# Data source for the H-k stacks: the traces that passed the similarity-to-mean filter.
hk_src_data = oa_quality_filt

In [None]:
# Plot stack
# Weights applied to the three stack components (plotted below as Ps, PpPs, PpSs + PsPs)
# when forming the combined stack.
weighting = (0.35, 0.35, 0.3)

# NOTE(review): compute_hk_stack, compute_weighted_stack and plot_hk_stack were called unqualified
# but are not defined in this notebook; assumed to live in rf_stacking and rf_plot_utils - confirm.
for cha in [channel]:
    k_grid, h_grid, hk_stack = rf_stacking.compute_hk_stack(hk_src_data, cha, root_order=2)

    hk_stack_sum = rf_stacking.compute_weighted_stack(hk_stack, weighting)

    sta = hk_src_data[cha][0].stats.station

    num = len(hk_src_data[cha])
    # No file is written by default; set a file name here to save the combined stack plot.
    save_file = None
    rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack[0], title=sta + '.{} Ps'.format(cha), num=num)
    rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack[1], title=sta + '.{} PpPs'.format(cha), num=num)
    rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack[2], title=sta + '.{} PpSs + PsPs'.format(cha), num=num)
    # NOTE(review): the title says "(no filtering)" although hk_src_data is the filtered set.
    rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha) + ' (no filtering)', num=num, save_file=save_file)

***

# Loop over all OA stations and plot HK-stacks

In [None]:
# cha = channel
# pbar = tqdm(total=len(db))
# show = False
# weighting = (0.5, 0.4, 0.1)
# for sta, db_sta in db.items():
#     pbar.set_description(sta)
#     pbar.update()
#     k_grid, h_grid, hk_stack = compute_hk_stack(db_sta, cha, root_order=2)
#     hk_stack_sum = compute_weighted_stack(hk_stack, weighting)
#     sta = db_sta[cha][0].stats.station
#     save_file = sta + "_{}_hk_stack.png".format(cha)
#     num = len(db_sta[cha])
#     plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha), save_file=save_file, show=show, num=num)
# pbar.close()