In [None]:
import os
from collections import defaultdict
import time

import numpy as np
import rf
import rf.imaging
import matplotlib.pyplot as plt
import scipy
from scipy import signal
from scipy.interpolate import interp1d
import obspy

from tqdm.auto import tqdm

In [None]:
# Location of the HDF5 archive that is read with rf.read_rf below. The default is the original
# hard-coded path; set the OA_RF_SRC_FILE environment variable to point the notebook at a copy
# stored elsewhere without editing this cell.
src_file = os.environ.get(
    "OA_RF_SRC_FILE",
    "/g/data/ha3/am7399/shared/OA_event_waveforms_for_rf_20170911T000036-20181128T230620_ZRT_td_rev3.h5",
)

In [None]:
# Read the whole archive at src_file. The commented-out variant passes a `group` argument to
# restrict the read to a single group of the file.
# oa_all = rf.read_rf(src_file, format='h5', group='/waveforms/OA.BT23.0M')
oa_all = rf.read_rf(src_file, format='h5')

In [None]:
def rf_to_dict(rf_data):
    """Index traces by station code, then by channel code.

    Every item of rf_data must expose an `id` string made of exactly four dot-separated
    fields. The second field is used as the station key and the fourth as the channel key.

    :param rf_data: Iterable of trace-like objects.
    :return: Nested defaultdict such that result[station][channel] is the list of traces
        with that station and channel, in iteration order.
    """
    by_station = defaultdict(lambda: defaultdict(list))
    for trace in rf_data:
        # Tuple unpacking deliberately fails if the id does not have four fields.
        _first, station, _third, channel = trace.id.split('.')
        by_station[station][channel].append(trace)
    return by_station

In [None]:
def plot_station_rf_overlays(db_station):
    """Overlay all traces of each channel of one station and draw the channel mean on top.

    One subplot row is drawn per channel (at least three rows are laid out). The time axis of
    each trace is shifted so that `stats.onset` falls at zero, and all rows are given the
    widest x-range found across the channels.

    The mean is computed on a fresh array, so the `data` of the input traces is never modified.

    :param db_station: Mapping of channel code -> list of trace-like objects. Each trace must
        provide `data`, `times()`, `stats.station`, `stats.onset` and `stats.starttime`. All
        traces of one channel must have the same number of samples. Channels without traces
        are ignored.
    """
    plt.figure(figsize=(16,24))
    colors = ["#8080a040", "#80a08040", "#a0808040"]
    # Channels with no traces have nothing to plot and no mean to compute.
    channels = [(ch, streams) for ch, streams in db_station.items() if len(streams) > 0]
    if not channels:
        return
    num_rows = max(3, len(channels))
    min_x = 1e+20
    max_x = -1e20
    for i, (ch, streams) in enumerate(channels):
        # Cycle the palette so that more than three channels do not raise IndexError.
        col = colors[i % len(colors)]
        plt.subplot(num_rows, 1, i + 1)
        sta = streams[0].stats.station
        plt.title('.'.join([sta, ch]), fontsize=14)
        for s in streams:
            lead_time = s.stats.onset - s.stats.starttime
            plt.plot(s.times() - lead_time, s.data, '--', color=col, linewidth=2)
        # Arithmetic mean over all N traces, built from a new array. Accumulating into the
        # first trace's `data` would overwrite that trace, and dividing by the last loop
        # index would divide by N - 1 (zero for a single trace).
        data_mean = np.mean(np.array([s.data for s in streams], dtype=float), axis=0)
        # The mean is drawn against the shifted time axis of the last trace of the channel.
        plt.plot(s.times() - lead_time, data_mean, color="#202020", linewidth=2)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude (normalized)')
        plt.grid(linestyle=':', color="#80808020")
        x_lims = plt.xlim()
        min_x = min(min_x, x_lims[0])
        max_x = max(max_x, x_lims[1])
    # Apply the common x-range to every row so the channels can be compared visually.
    for i in range(num_rows):
        subfig = plt.subplot(num_rows, 1, i + 1)
        subfig.set_xlim((min_x, max_x))

In [None]:
def signed_nth_root(arr, order):
    """Sign-preserving root: sign(arr) * |arr|**(1/order), evaluated element-wise.

    When order == 1 the input object itself is returned (no copy is made).
    """
    if order != 1:
        magnitude = np.power(np.abs(arr), 1.0/order)
        return np.sign(arr)*magnitude
    return arr

In [None]:
def signed_nth_power(arr, order):
    """Sign-preserving power: sign(arr) * |arr|**order, evaluated element-wise.

    When order == 1 the input object itself is returned (no copy is made).
    """
    if order != 1:
        magnitude = np.power(np.abs(arr), order)
        return np.sign(arr)*magnitude
    return arr

In [None]:
def compute_hk_stack(db_station, cha, h_range=np.linspace(10.0, 70.0, 301), k_range = np.linspace(1.3, 2.1, 201),
                     V_p = 6.4, root_order=1, include_t3=True):
    """Compute per-phase H-k stacks for one channel of a station.

    For every trace of `db_station[cha]` the arrival times of the Ps, PpPs and (optionally)
    PpSs + PsPs phases are evaluated on the (H, k) grid, the trace is sampled at those times by
    linear interpolation, and the samples are averaged over traces with NaN values ignored.

    :param db_station: Mapping of channel code -> list of trace-like objects. Each trace must
        provide `data`, `times()`, `stats.inclination` (converted here from degrees to radians),
        `stats.onset` and `stats.starttime`.
    :param cha: Channel code to stack.
    :param h_range: 1D array of H values (the plotting code labels this axis as Moho depth in km).
    :param k_range: 1D array of k = Vp/Vs values.
    :param V_p: P-wave velocity used to convert H into time (presumably km/s, to match H in km
        -- confirm).
    :param root_order: Order of the signed nth root applied before stacking; the stacked result
        is raised back to this power. 1 disables the technique.
    :param include_t3: Whether to include the third (PpSs + PsPs) phase.
    :return: Tuple (k_grid, h_grid, hk_stack). The grids come from np.meshgrid(k_range, h_range)
        and hk_stack has shape (number of phases, len(h_range), len(k_range)).
    :raises ValueError: If there are no traces for channel `cha`.
    """
    cha_data = db_station[cha]
    if len(cha_data) == 0:
        # Without this check np.nanmean of an empty list yields a scalar NaN, which only fails
        # later (and obscurely) when the phase components are combined.
        raise ValueError("No traces available for channel {!r}; cannot compute H-k stack".format(cha))

    # Pre-compute grid quantities
    k_grid, h_grid = np.meshgrid(k_range, h_range)
    H_on_V_p = h_grid/V_p
    k2 = k_grid*k_grid

    stream_stack = []
    # Loop over streams, compute times, and stack interpolated values at those times
    for s in cha_data:
        incidence = s.stats.inclination
        incidence_rad = incidence*np.pi/180.0
        cos_i, sin_i = np.cos(incidence_rad), np.sin(incidence_rad)
        sin2_i = sin_i*sin_i
        term1 = H_on_V_p*k_grid*np.abs(cos_i)
        # Grid points where k^2*sin^2(i) > 1 give NaN here; they are ignored by np.nanmean below.
        term2 = H_on_V_p*np.sqrt(1 - k2*sin2_i)
        # Time for Ps
        t1 = term1 - term2
        # Time for PpPs
        t2 = term1 + term2
        if include_t3:
            # Time for PpSs + PsPs
            t3 = 2*term1

        # Subtract lead time so that primary P-wave arrival is at time zero.
        lead_time = s.stats.onset - s.stats.starttime
        times = s.times() - lead_time
        # Create interpolator from stream signal for accurate time sampling.
        # bounds_error=False makes times outside the trace evaluate to NaN instead of raising.
        interpolator = interp1d(times, s.data, kind='linear', copy=False, bounds_error=False, assume_sorted=True)

        phase_sum = []
        phase_sum.append(signed_nth_root(interpolator(t1), root_order))
        phase_sum.append(signed_nth_root(interpolator(t2), root_order))
        if include_t3:
            # Negative sign on the third term is intentional, see Chen et al. (2010) and Zhu & Kanamori (2000).
            # It needs to be negative because the PpSs + PsPs peak has negative phase,
            # see http://eqseis.geosc.psu.edu/~cammon/HTML/RftnDocs/rftn01.html
            # Apply nth root technique to reduce uncorrelated noise (Chen et al. (2010))
            phase_sum.append(-signed_nth_root(interpolator(t3), root_order))

        stream_stack.append(phase_sum)

    # Perform the stacking (mean) across streams. hk_stack retains separate t1, t2, and t3 components here.
    hk_stack = np.nanmean(np.array(stream_stack), axis=0)

    # This inversion of the nth root is different to Sippl and Chen, but consistent with Muirhead
    # who proposed the nth root technique. It improves the contrast of the resulting plot.
    if root_order != 1:
        hk_stack = signed_nth_power(hk_stack, root_order)

    return k_grid, h_grid, hk_stack

In [None]:
def compute_weighted_stack(hk_components, weighting=(0.5, 0.5, 0.0)):
    """Collapse the leading axis of hk_components into a single weighted sum.

    :param hk_components: Array whose first axis has one entry per weight.
    :param weighting: Sequence of weights, one per entry along the first axis.
    :return: Array of shape hk_components.shape[1:] equal to
        sum over i of weighting[i]*hk_components[i].
    """
    assert hk_components.shape[0] == len(weighting), hk_components.shape
    weights = np.array(weighting)
    # Move the component axis last so that np.dot contracts it against the weights.
    components_last = np.moveaxis(hk_components, 0, -1)
    return np.dot(components_last, weights)

In [None]:
def plot_hk_stack(k_grid, h_grid, hk_stack, title=None, save_file=None, show=True, num=None, clip_negative=True):
    """Draw a filled contour plot of a combined H-k stack, optionally saving it to file.

    Call compute_weighted_stack() first to combine weighted components before calling this function.

    :param k_grid: 2D grid of k values (x axis).
    :param h_grid: 2D grid of H values (y axis).
    :param hk_stack: 2D stack values on the grid. This array is not modified.
    :param title: Optional figure title.
    :param save_file: Optional output file name. Saving is attempted up to 10 times, one second
        apart, when a PermissionError is raised; a warning is printed if every attempt fails.
    :param show: If True the figure is shown, otherwise it is closed.
    :param num: Optional count that is written onto the plot as "N = <num>".
    :param clip_negative: If True, negative stack values are plotted as zero.
    """
    # Use a perceptually linear color map.
    colmap = 'plasma'
    plt.figure(figsize=(16, 12))
    if clip_negative:
        # np.clip returns a new array, so the caller's stack keeps its negative values.
        hk_stack = np.clip(hk_stack, 0, None)
    plt.contourf(k_grid, h_grid, hk_stack, levels=50, cmap=colmap)
    cb = plt.colorbar()
    cb.ax.set_ylabel('Stack sum')
    plt.contour(k_grid, h_grid, hk_stack, levels=10, colors='k', linewidths=1)
    plt.xlabel(r'$\kappa = \frac{V_p}{V_s}$ (ratio)', fontsize=14)
    plt.ylabel('H = Moho depth (km)', fontsize=14)
    if title is not None:
        plt.title(title, fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    if num is not None:
        # Place the label at 85% / 95% of the axis ranges.
        xl = plt.xlim()
        yl = plt.ylim()
        txt_x = xl[0] + 0.85*(xl[1] - xl[0])
        txt_y = yl[0] + 0.95*(yl[1] - yl[0])
        plt.text(txt_x, txt_y, "N = {}".format(num), color="#ffffff", fontsize=16, fontweight='bold')

    if save_file is not None:
        max_tries = 10
        for attempt in range(1, max_tries + 1):
            try:
                plt.savefig(save_file, dpi=300)
                break
            except PermissionError:
                if attempt == max_tries:
                    print("WARNING: Failed to save file {} due to permissions!".format(save_file))
                else:
                    # Wait before retrying; there is no point in sleeping after the last attempt.
                    time.sleep(1)
            # end try
        # end for

    if show:
        plt.show()
    else:
        plt.close()

In [None]:
def filter_station_streams(db_station, freq_band=(None, None)):
    """Return a replica of db_station whose traces are frequency-filtered copies.

    Only the first three entries of db_station (in iteration order) are processed, and an entry
    keyed 'size' is skipped. The input traces are copied before filtering.

    :param db_station: Mapping of channel code -> list of traces providing `copy()` and `filter()`.
    :param freq_band: Pair (low, high). (None, high) applies a lowpass, (low, None) a highpass,
        (low, high) a bandpass, and (None, None) leaves the copies unfiltered.
    :return: defaultdict(list) of channel code -> list of filtered copies.
    """
    def _filtered_copy(trace):
        # Copy first so that the caller's trace is never filtered in place.
        duplicate = trace.copy()
        low_cut, high_cut = freq_band[0], freq_band[1]
        if low_cut is None:
            if high_cut is not None:
                duplicate.filter('lowpass', zerophase=True, corners=2, freq=high_cut)
        elif high_cut is None:
            duplicate.filter('highpass', zerophase=True, corners=2, freq=low_cut)
        else:
            duplicate.filter('bandpass', zerophase=True, corners=2, freqmin=low_cut,
                             freqmax=high_cut)
        return duplicate

    filtered = defaultdict(list)
    for position, (ch, streams) in enumerate(db_station.items()):
        if ch != 'size' and position < 3:
            for trace in streams:
                filtered[ch].append(_filtered_copy(trace))

    return filtered

In [None]:
def filter_station_to_mean_signal(db_station, min_correlation=1.0):
    """Filter out streams which are not 'close enough' to the mean signal of their channel.

    The score of a trace is dot(trace, mean)/dot(mean, mean), i.e. the projection coefficient
    of the trace onto the channel mean (not a normalized correlation coefficient). Traces with
    a score of at least `min_correlation` are kept.

    Only the first three entries of db_station (in iteration order) are considered, an entry
    keyed 'size' is skipped, and so are channels without traces. The mean is the arithmetic
    mean over all traces of the channel and is computed on a fresh array, so the `data` of the
    input traces is never modified.

    :param db_station: Mapping of channel code -> list of trace-like objects with a `data`
        array. All traces of one channel must have the same number of samples.
    :param min_correlation: Minimum score for a trace to be retained.
    :return: defaultdict(list) of channel code -> list of retained traces (the same objects
        as in the input, not copies).
    """
    # Compute mean signals of channels in station, keyed by channel so that skipped entries
    # cannot shift which mean is paired with which channel.
    mean_rfs = {}
    for i, (ch, streams) in enumerate(db_station.items()):
        if ch == 'size' or i >= 3 or len(streams) == 0:
            continue
        mean_rfs[ch] = np.mean(np.array([s.data for s in streams], dtype=float), axis=0)
    # end for

    # Filter out signals that do not meet minimum coherence with mean signal for each channel
    db_station_filt = defaultdict(list)

    for ch, data_mean in mean_rfs.items():
        mean_power = np.dot(data_mean, data_mean)
        if mean_power == 0:
            # An all-zero mean gives an undefined score; no trace can qualify.
            continue
        for s in db_station[ch]:
            corr = np.dot(s.data, data_mean)/mean_power
            if corr >= min_correlation:
                db_station_filt[ch].append(s)
        # end for
    # end for

    return db_station_filt

# Select test station and plot overlaid stream data per channel

In [None]:
# Inspect the type of the container returned by rf.read_rf.
type(oa_all)

In [None]:
# Index all traces so that db[station][channel] is the list of traces for that pair.
db = rf_to_dict(oa_all)

In [None]:
# import pandas as pd
# sta_codes = list(db.keys())
# sta_lat = [db[sta]['HHR'][0].stats.station_latitude for sta in sta_codes]
# sta_lon = [db[sta]['HHR'][0].stats.station_longitude for sta in sta_codes]
# df_OA_sta = pd.DataFrame.from_dict({'Station': sta_codes, 'Latitude': sta_lat, 'Longitude': sta_lon})
# df_OA_sta.to_csv('OA_stations.txt', index=False)

In [None]:
# # Get coordinates of select stations for north-south lines
# for sta in ['BV21', 'BV28', 'BX20', 'BX28', 'BY20', 'BY28', 'CB20', 'CB28', 'CD21', 'CD28']:
#     lat = db[sta]['HHR'][0].stats.station_latitude
#     lon = db[sta]['HHR'][0].stats.station_longitude
#     print("{}: ({} {})".format(sta, lat, lon))

In [None]:
# # Get coordinates of select stations for west-east lines
# for sta in ['BU22', 'CI22', 'BS26', 'CE26', 'BS28', 'CF28']:
#     lat = db[sta]['HHR'][0].stats.station_latitude
#     lon = db[sta]['HHR'][0].stats.station_longitude
#     print("{}: ({} {})".format(sta, lat, lon))

In [None]:
# # Get coordinates of select stations for Andrew Clark
# for i, sta_pair in enumerate([('BV21', 'CC28'), ('BX20', 'BX28'), ('BY20', 'BY28')]):
#     if i == 0:
#         # NE-->SW line
#         lat0 = db[sta_pair[0]]['HHR'][0].stats.station_latitude + 0.1
#         lon0 = db[sta_pair[0]]['HHR'][0].stats.station_longitude - 0.1
#         lat1 = db[sta_pair[1]]['HHR'][0].stats.station_latitude - 0.1
#         lon1 = db[sta_pair[1]]['HHR'][0].stats.station_longitude + 0.1
#     else:
#         # N-->S line
#         lat0 = db[sta_pair[0]]['HHR'][0].stats.station_latitude + 0.1
#         lon0 = db[sta_pair[0]]['HHR'][0].stats.station_longitude
#         lat1 = db[sta_pair[1]]['HHR'][0].stats.station_latitude - 0.1
#         lon1 = db[sta_pair[1]]['HHR'][0].stats.station_longitude
#     print("{}-{}: --start-latlon {} {} --end-latlon {} {}".format(sta_pair[0], sta_pair[1], lat0, lon0, lat1, lon1))

In [None]:
# Station used as the test case in the cells that follow.
oa_test = db['BT23']

In [None]:
# Inspect the type of a single trace of the test station.
type(oa_test['HHR'][0])

In [None]:
# tr_copy = oa_test['HHR'][0].copy()
# tr_copy.filter('lowpass', zerophase=True, corners=2, freq=5.0)
# tr_copy.plot()

In [None]:
# Display the metadata attached to the first HHR trace of the test station.
oa_test['HHR'][0].stats

In [None]:
# Overlay all traces per channel for the unfiltered test station.
plot_station_rf_overlays(oa_test)

In [None]:
# Step 1: highpass-filtered copies of the test station traces (corner value 0.25, no upper limit).
oa_test_filt = filter_station_streams(oa_test, freq_band=(0.25, None))
# Step 2: keep only traces whose score against the channel mean is at least 1.1.
# The name oa_test_filt is rebound here to the result of the second step.
oa_test_filt = filter_station_to_mean_signal(oa_test_filt, min_correlation=1.1)

In [None]:
# Overlay the traces that survived filtering, for comparison with the unfiltered plot.
plot_station_rf_overlays(oa_test_filt)

In [None]:
# Test phase weighting
# s0 is the first filtered HHR trace; it serves as the reference trace in the cells below.
s0 = oa_test_filt['HHR'][0]
# Offset between trace start and onset, used to put the onset at time zero on the x axis.
time_offset = s0.stats.onset - s0.stats.starttime
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, s0.data)
plt.show()

In [None]:
# Analytic signal of the reference trace. The sample array (s0.data) is passed explicitly,
# as is done for every trace in the phase-coherence loop further down, rather than relying
# on the trace object being converted to an array implicitly.
s0_h = signal.hilbert(s0.data)

In [None]:
# Display the imaginary part of the analytic signal.
np.imag(s0_h)

In [None]:
# Instantaneous phase angle (radians, wrapped) of the analytic signal.
s0_angle = np.angle(s0_h)

In [None]:
# Plot the unwrapped instantaneous phase of the reference trace against onset-relative time.
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, np.unwrap(s0_angle))
plt.show()

In [None]:
# Unit-magnitude complex phasor carrying only the instantaneous phase of the reference trace.
s0_iphase = np.exp(1j*s0_angle)

In [None]:
# Phase coherence across all filtered HHR traces: average the unit phasors of the
# instantaneous phase over traces, take the magnitude, and scale so the maximum is 1.
angles = []
unit_phasors = []
for s in oa_test_filt['HHR']:
    inst_phase = np.angle(signal.hilbert(s.data))
    # Unwrapped phase is kept for plotting; the wrapped phase feeds the phasor.
    angles.append(np.unwrap(inst_phase))
    unit_phasors.append(np.exp(1j*inst_phase))
tphase = np.array(unit_phasors)
tphase_mean = np.abs(tphase.mean(axis=0))
tphase_norm = tphase_mean/tphase_mean.max()

In [None]:
# Plot the unwrapped phase of every filtered HHR trace (one line per trace). The time axis of
# the reference trace s0 is used for all of them.
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, np.array(angles).T)
plt.show()

In [None]:
# Plot the normalized phase coherence (maximum scaled to 1) against onset-relative time.
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, tphase_norm)
plt.show()

In [None]:
# Demonstrate effect of phase weighting to suppress areas where phases tend to be random.
# The reference trace is drawn as-is and again after sample-wise multiplication by the
# normalized phase coherence.
plt.figure(figsize=(16,9))
plt.plot(s0.times() - time_offset, s0.data, linewidth=2)
plt.plot(s0.times() - time_offset, s0.data*tphase_norm, '--', linewidth=2)
plt.legend(['Original', 'Phase weighted'])
plt.grid()
plt.show()

# Run over all channels to verify that HHR is the correct channel to plot

In [None]:
# Plot the weighted H-k stack for the unfiltered test station. Only HHR is enabled below;
# swap in the commented-out loop header to run all three channels.
# The weights are applied to the phase components in the order they are stacked
# (Ps, PpPs, PpSs + PsPs).
weighting = (0.4, 0.4, 0.2)
# for cha in ['HHR', 'HHT', 'HHZ']:
for cha in ['HHR']:
    k_grid, h_grid, hk_stack = compute_hk_stack(oa_test, cha, root_order=2)

    hk_stack_sum = compute_weighted_stack(hk_stack, weighting)
    
    sta = oa_test[cha][0].stats.station

    # Number of traces in the stack, shown on the plot as "N = ...".
    num = len(oa_test[cha])
    # save_file=None means the figure is only displayed, not written to disk.
    save_file = None
    # The commented-out calls plot each phase component separately.
#     plot_hk_stack(k_grid, h_grid, hk_stack[0], title=sta + '.{} Ps'.format(cha), num=num)
#     plot_hk_stack(k_grid, h_grid, hk_stack[1], title=sta + '.{} PpPs'.format(cha), num=num)
#     plot_hk_stack(k_grid, h_grid, hk_stack[2], title=sta + '.{} PpSs + PsPs'.format(cha), num=num)
    plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha), num=num, save_file=save_file)

In [None]:
# Plot the weighted H-k stack for the filtered test station data (HHR only), using the same
# weights and root order as for the unfiltered data so the two plots are comparable.
weighting = (0.4, 0.4, 0.2)
for cha in ['HHR']:
    k_grid, h_grid, hk_stack = compute_hk_stack(oa_test_filt, cha, root_order=2)

    hk_stack_sum = compute_weighted_stack(hk_stack, weighting)

    sta = oa_test_filt[cha][0].stats.station

    # Number of traces that survived filtering, shown on the plot as "N = ...".
    num = len(oa_test_filt[cha])
    # save_file=None means the figure is only displayed, not written to disk.
    save_file = None
    plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha) + " (filtered)", num=num, save_file=save_file)

# Loop over all OA stations and plot HK-stacks

In [None]:
# Compute and save an H-k stack plot for every station. Figures are written to
# "<station>_<channel>_hk_stack.png" in the working directory and are not displayed.
cha = 'HHR'
show = False
weighting = (0.5, 0.4, 0.1)
# The context manager closes the progress bar even if a station raises.
with tqdm(total=len(db)) as pbar:
    for sta, db_sta in db.items():
        pbar.set_description(sta)
        pbar.update()
        # .get() avoids creating an empty entry in the defaultdict for a missing channel.
        if not db_sta.get(cha):
            # A station without traces for this channel cannot be stacked; skip it instead of
            # aborting the whole batch.
            print("Skipping station {}: no {} traces".format(sta, cha))
            continue
        k_grid, h_grid, hk_stack = compute_hk_stack(db_sta, cha, root_order=2)
        hk_stack_sum = compute_weighted_stack(hk_stack, weighting)
        sta = db_sta[cha][0].stats.station
        save_file = sta + "_{}_hk_stack.png".format(cha)
        num = len(db_sta[cha])
        plot_hk_stack(k_grid, h_grid, hk_stack_sum, title=sta + '.{}'.format(cha), save_file=save_file, show=show, num=num)

## Plot histogram of group ID distribution

In [None]:
# Collect the rf_group value of every HHR trace across all stations.
groups = [
    s.stats.rf_group
    for db_sta in db.values()
    for s in db_sta['HHR']
]

In [None]:
# Histogram of the rf_group values, labelled so the figure can be read on its own.
plt.hist(groups)
plt.xlabel('RF group ID')
plt.ylabel('Number of HHR traces')
plt.title('Distribution of RF group IDs')
plt.show()

## Plot histogram of inclinations

In [None]:
# Collect the inclination of every HHR trace across all stations.
incs = [
    s.stats.inclination
    for db_sta in db.values()
    for s in db_sta['HHR']
]

In [None]:
# Histogram of the inclinations, labelled so the figure can be read on its own.
# Degrees are assumed because compute_hk_stack converts stats.inclination with pi/180.
plt.hist(incs)
plt.xlabel('Inclination (degrees)')
plt.ylabel('Number of HHR traces')
plt.title('Distribution of inclinations')
plt.show()