In [None]:
import os
# from collections import defaultdict
# import time
import json
import pickle as pkl
import base64

import numpy as np
import rf
import rf.imaging
import matplotlib.pyplot as plt
import scipy
from scipy import signal
from scipy.signal import hilbert
from scipy.stats import moment
# from scipy.interpolate import interp1d
from sklearn.neural_network import MLPClassifier
import obspy
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [None]:
# Bring in interactive widgets capability. See https://towardsdatascience.com/interactive-controls-for-jupyter-notebooks-f5c94829aee6
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

In [None]:
import seismic.receiver_fn.rf_util as rf_util
import seismic.receiver_fn.rf_plot_utils as rf_plot_utils
import seismic.receiver_fn.rf_stacking as rf_stacking

## Read source file

In [None]:
# Receiver-function HDF5 file to analyse. Built with os.path.join so the relative path resolves on
# both Windows and POSIX (a backslash-separated literal only works on Windows).
src_file = os.path.join("..", "DATA", "OA_event_waveforms_for_rf_20170911T000036-20181128T230620_LQT_td_rev3_qual.h5")

In [None]:
# Read all receiver-function traces from src_file (the next cell displays the returned container type).
oa_all = rf_util.read_h5_rf(src_file)

In [None]:
# Display the container type returned by rf_util.read_h5_rf.
type(oa_all)

## Convert RFStream to dict database for convenient iteration and addressing

In [None]:
# Dict "database" of the traces; it is indexed below as db[station][channel].
db = rf_util.rf_to_dict(oa_all)

## Select test station and channel

In [None]:
# Station under test. Other stations examined previously: 'BT23', 'BZ20'.
test_station = 'BS27'
oa_test = db[test_station]

In [None]:
# Channel to analyse for the selected station; the last expression displays its trace count.
channel = 'HHQ'
channel_data = oa_test[channel]
len(channel_data)

In [None]:
# Check if there are any traces with NaNs in them. RF quality filtering prior to this SHOULD have removed any such traces.
# Number of traces whose data contains at least one NaN. RF quality filtering done before
# this notebook SHOULD make this zero.
has_nan = [np.isnan(tr.data).any() for tr in channel_data]
np.sum(has_nan)

## Add additional statistics for prediction of trace quality

In [None]:
# This needs to be done before running quality classifier over the traces.
# This needs to be done before running quality classifier over the traces.
# NOTE(review): presumably this adds the per-trace stats read later as classifier features
# (entropy, max_coherence, amp_20pc, ...) - confirm against rf_util.compute_extra_rf_stats.
# It is applied to oa_test, of which channel_data is the oa_test[channel] entry.
rf_util.compute_extra_rf_stats(oa_test)

## Examine available metadata in each trace

In [None]:
# Display the container type holding this channel's traces.
type(channel_data)

In [None]:
# Display the type of a single trace.
type(channel_data[0])

In [None]:
# Display the metadata (stats) attached to the first trace.
channel_data[0].stats

## Load quality classifier

In [None]:
# Decode the serialized quality-classifier parameters held in the dict `model`.
# NOTE(review): `model` is never created in this notebook; the cell that reads the classifier file
# (presumably with json, which is imported above) appears to be missing. TODO restore it.
# SECURITY: pickle.loads can execute arbitrary code - only decode classifier files from a trusted source.
if 'model' not in globals():
    raise NameError("`model` is not defined: load the serialized quality classifier before running this cell")


def _unpickle_b64(value):
    """Return the object pickled and base64-encoded in ``value``.

    Values that are not str/bytes are taken to be decoded already and are returned unchanged,
    so re-running this cell does not attempt to decode the same field twice.
    """
    if isinstance(value, (str, bytes)):
        return pkl.loads(base64.b64decode(value))
    return value


model['coeffs'] = _unpickle_b64(model['coeffs'])
model['biases'] = _unpickle_b64(model['biases'])
model['classes'] = np.array(model['classes'])
# The label binarizer is decoded from its own field (not from 'biases', which is already unpickled above).
model['binarizer'] = _unpickle_b64(model['binarizer'])

In [None]:
# Rebuild the fitted MLP quality classifier from the decoded parameters by assigning its
# fitted-state attributes directly (no training happens here).
# NOTE(review): '_label_binarizer' is a private scikit-learn attribute and may change between versions.
qc = MLPClassifier(**model['params'])
fitted_state = (
    ('coefs_', 'coeffs'),
    ('intercepts_', 'biases'),
    ('classes_', 'classes'),
    ('out_activation_', 'out_activation'),
    ('n_outputs_', 'n_outputs'),
    ('n_layers_', 'n_layers'),
    ('_label_binarizer', 'binarizer'),
)
for attr_name, model_key in fitted_state:
    setattr(qc, attr_name, model[model_key])

## Apply quality filter to traces

In [None]:
# Classifier feature matrix: one row per trace, one column per entry of stats_metrics (same order,
# presumably the order the classifier was trained with). NaN statistics are replaced by 0.
stats_metrics = ["SNR", "Entropy", "Coherence", "Max_amp", "Amp_20pc", "Amp_80pc", "RMS_amp", "Mean_amp"]
feature_attrs = ('snr', 'entropy', 'max_coherence', 'amax', 'amp_20pc', 'amp_80pc', 'rms_amp', 'mean_cplx_amp')
X = np.array([[getattr(tr.stats, name) for name in feature_attrs] for tr in channel_data])
X[np.isnan(X)] = 0

In [None]:
# Predict a quality label per trace and store it on the trace; rows of X follow channel_data order.
predicted_quality = qc.predict(X)
for tr, label in zip(channel_data, predicted_quality):
    tr.stats.predicted_quality = label

In [None]:
# Stream of the traces predicted as quality 'a'; display how many there are.
quality_a_pred = [tr for tr in channel_data if tr.stats.predicted_quality == 'a']
rf_stream_A = rf.RFStream(quality_a_pred)
len(rf_stream_A)

## Plot RFs for traces filtered by various quality metrics

In [None]:
# Upper limit on the number of traces drawn in each RF stack plot below.
max_traces = 50

### Quality A

In [None]:
# Plot traces with ground-truth quality 'a', sorted by back azimuth, if the data carries labels.
# Only the label lookup is guarded, so an AttributeError raised while sorting or plotting is no
# longer misreported as missing ground-truth labels.
try:
    rf_data = [tr for tr in channel_data if tr.stats.quality == 'a']
except AttributeError:
    print("Data has no ground truth quality labels")
else:
    rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
    rf_stream = rf.RFStream(rf_data)
    rf_plot_utils.plot_rf_stack(rf_stream[0:max_traces])

### Quality B

In [None]:
# Plot traces with ground-truth quality 'b', sorted by back azimuth, if the data carries labels.
# Only the label lookup is guarded, so an AttributeError raised while sorting or plotting is no
# longer misreported as missing ground-truth labels.
try:
    rf_data = [tr for tr in channel_data if tr.stats.quality == 'b']
except AttributeError:
    print("Data has no ground truth quality labels")
else:
    rf_data = sorted(rf_data, key=lambda v: v.stats.back_azimuth)
    rf_stream = rf.RFStream(rf_data)
    rf_plot_utils.plot_rf_stack(rf_stream[0:max_traces])

### Predicted Quality A

In [None]:
# Traces predicted as quality 'a', ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.predicted_quality == 'a'),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### Predicted Quality B

In [None]:
# Traces predicted as quality 'b', ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.predicted_quality == 'b'),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### High SNR

In [None]:
# Traces with SNR of at least 3.0, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.snr >= 3.0),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### Low SNR

In [None]:
# Traces with SNR of at most 0.8, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.snr <= 0.8),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### High entropy

In [None]:
# Traces with entropy of at least 4.2, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.entropy >= 4.2),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### Low entropy

In [None]:
# Traces with entropy of at most 3.0, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.entropy <= 3.0),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### High coherence

In [None]:
# Traces with max coherence of at least 0.3, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.max_coherence >= 0.3),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### Low coherence

In [None]:
# Traces with max coherence of at most 0.02, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.max_coherence <= 0.02),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### High magnitude

In [None]:
# Traces from events of magnitude 5.5 or more, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.event_magnitude >= 5.5),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

### Low magnitude

In [None]:
# Traces from events of magnitude below 5.5, ordered by back azimuth; plot at most max_traces of them.
rf_data = sorted((tr for tr in channel_data if tr.stats.event_magnitude < 5.5),
                 key=lambda trace: trace.stats.back_azimuth)
rf_stream = rf.RFStream(rf_data)
rf_plot_utils.plot_rf_stack(rf_stream[:max_traces])

***

## Plot overlay of predicted quality A traces in test channel

In [None]:
# Station-style dict {channel: traces} holding only the traces predicted as quality 'a'.
# These are the same trace objects as in channel_data (no copies are made).
predicted_a_traces = [tr for tr in channel_data if tr.stats.predicted_quality == 'a']
oa_quality = {channel: predicted_a_traces}

In [None]:
# Overlay the predicted quality A traces held in oa_quality; the returned value is kept as trace_mean.
# The title counts the traces in oa_quality (the ones actually plotted), not all of channel_data.
num_traces = len(oa_quality[channel])
trace_mean = rf_plot_utils.plot_station_rf_overlays(oa_quality, '({} predicted quality A traces)'.format(num_traces))

## Plot only traces that are similar to the mean trace

In [None]:
# Threshold passed to filter_station_to_mean_signal; traces are kept based on their correlation
# with the mean signal, and the correlations are returned in corrs.
min_correlation = 0.05
oa_quality_filt, corrs = rf_util.filter_station_to_mean_signal(oa_quality, min_correlation=min_correlation)

In [None]:
# Distribution of the correlations returned by filter_station_to_mean_signal, with a title and
# axis labels so the figure stands on its own.
fig, ax = plt.subplots()
ax.hist(corrs, bins=50)
ax.set_title('{} correlation with mean signal'.format(channel))
ax.set_xlabel('Correlation with mean signal')
ax.set_ylabel('Count')
plt.show()

In [None]:
# Overlay only the traces that survived the similarity filter.
num_traces = len(oa_quality_filt[channel])
overlay_title = '({} traces similar to mean)'.format(num_traces)
test_filt_mean = rf_plot_utils.plot_station_rf_overlays(oa_quality_filt, overlay_title)

## Demonstrate the effectiveness of phase-weighting the traces

In [None]:
from seismic.receiver_fn.rf_util import phase_weights

In [None]:
# Phase weights computed across the similarity-filtered traces of the test channel.
pw = rf_util.phase_weights(oa_quality_filt[channel])

In [None]:
# Plot the phase weights against time relative to the onset of the first filtered trace.
s0 = oa_quality_filt[channel][0]
# Offset (s) of the onset from the trace start; s0.times() is seconds from the trace start.
time_offset = s0.stats.onset - s0.stats.starttime
fig, ax = plt.subplots(figsize=(16, 9))
ax.plot(s0.times() - time_offset, pw)
ax.set_title('Phase weightings')
ax.set_xlabel('Time relative to onset (s)')
ax.set_ylabel('Phase weight')
ax.grid()
plt.show()

In [None]:
# Demonstrate effect of phase weighting to suppress areas where phases tend to be random.
pw_exponent = 2
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, s0.data, linewidth=2)
plt.plot(s0.times() - time_offset, s0.data*pw**pw_exponent, '--', linewidth=2)
plt.legend(['Original', 'Phase weighted'])
plt.title('Phase weighting applied to a single trace')
plt.grid()
plt.show()

In [None]:
# Apply phase weighting to data for H-k stacking
# NOTE: This will overwrite the original filtered data
# Apply phase weighting to data for H-k stacking.
# NOTE: This will overwrite the original filtered data.
# NOTE(review): re-running this cell compounds the weighting. The traces may also be the same objects
# held in channel_data/db (oa_quality does not copy them) - TODO confirm whether
# filter_station_to_mean_signal makes copies.
weight_factor = pw**pw_exponent
for tr in oa_quality_filt[channel]:
    tr.data = tr.data*weight_factor

num_traces = len(oa_quality_filt[channel])
weighted_title = '({} traces similar to mean, phase weighted)'.format(num_traces)
test_filt_mean = rf_plot_utils.plot_station_rf_overlays(oa_quality_filt, weighted_title)

# Plot HK stacks

In [None]:
# Data used for H-k stacking: an alias (not a copy) of oa_quality_filt, i.e. the predicted quality A,
# mean-similar traces, phase weighted once the weighting cell above has run.
hk_src_data = oa_quality_filt

In [None]:
# Plot stack
# Compute the H-k stacks for the test channel, plot one figure per phase stack, then plot the
# weighted combination of the three using `weighting`.
# NOTE(review): hk_src_data has been quality-filtered, similarity-filtered and phase-weighted, so the
# "(no filtering)" title suffix may be misleading - confirm what it is meant to convey.
weighting = (0.35, 0.35, 0.3)
phase_title_formats = ('.{} Ps', '.{} PpPs', '.{} PpSs + PsPs')

for cha in [channel]:
    k_grid, h_grid, hk_stack = rf_stacking.compute_hk_stack(hk_src_data, cha, root_order=2)
    hk_stack_sum = rf_stacking.compute_weighted_stack(hk_stack, weighting)

    cha_traces = hk_src_data[cha]
    sta = cha_traces[0].stats.station
    num = len(cha_traces)
    save_file = None

    for phase_idx, title_fmt in enumerate(phase_title_formats):
        rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack[phase_idx], title=sta + title_fmt.format(cha), num=num)
    rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha) + ' (no filtering)',
                                num=num, save_file=save_file)

***

# Loop over all OA stations and plot HK-stacks

In [None]:
# Disabled batch job: compute a weighted H-k stack for every station in db and save each plot to
# "<station>_<channel>_hk_stack.png". Uncomment to run.
# The helper names carry their module prefixes (rf_stacking., rf_plot_utils.) because the bare names
# are not imported in this notebook.
# NOTE(review): a station without channel `cha` would raise KeyError on db_sta[cha]; the `show`
# keyword of plot_hk_stack is not exercised elsewhere here - confirm it exists.
# cha = channel
# pbar = tqdm(total=len(db))
# show = False
# weighting = (0.5, 0.4, 0.1)
# for sta, db_sta in db.items():
#     pbar.set_description(sta)
#     pbar.update()
#     k_grid, h_grid, hk_stack = rf_stacking.compute_hk_stack(db_sta, cha, root_order=2)
#     hk_stack_sum = rf_stacking.compute_weighted_stack(hk_stack, weighting)
#     sta = db_sta[cha][0].stats.station
#     save_file = sta + "_{}_hk_stack.png".format(cha)
#     num = len(db_sta[cha])
#     rf_plot_utils.plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha), save_file=save_file, show=show, num=num)
# pbar.close()