In [None]:
import os

In [None]:
import numpy as np
import scipy
import obspy
import rf
import matplotlib.pyplot as plt

In [None]:
import seismic.receiver_fn.rf_plot_utils as rf_plot_utils

In [None]:
# Event waveforms for a single station, loaded for the receiver-function comparison below.
# NOTE(review): this is an absolute, site-specific path; the notebook only runs where
# this file is reachable.
src_trace_file = "/g/data/ha3/am7399/rf_validation/7W_bilby/7W.BL05_event_waveforms_for_rf_filtered.h5"
# NOTE(review): format='h5' is presumably provided by an obspy plugin (e.g. obspyh5) - confirm it is installed.
src_data = obspy.read(src_trace_file, format='h5')

In [None]:
# Sanity check: number of traces that were loaded.
len(src_data)

In [None]:
# Inspect the metadata (stats) attached to the first loaded trace.
src_data[0].stats

In [None]:
# Compare inbuilt Toeplitz solver with scipy Toeplitz solver

In [None]:
# Replicate deconvt from rf and replace solver with scipy Toeplitz solver

In [None]:
def toeplitz_solver_scipy(a, b):
    """
    Solve linear system Ax=b for real symmetric Toeplitz matrix A.

    Thin wrapper around ``scipy.linalg.solve_toeplitz``. Only ``a`` is handed
    to scipy, which then takes the first row to be the conjugate of the first
    column; for real ``a`` that is the symmetric Toeplitz matrix built from
    ``a``.

    ``check_finite=False`` skips scipy's NaN/inf validation of the inputs, so
    the caller is responsible for passing finite data.

    :param a: first row (equivalently first column) of Toeplitz matrix A
    :param b: vector b
    :return: x=A^-1*b
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.linalg`` has been loaded (on older SciPy it only
    # works if some other package happened to import it first).
    from scipy.linalg import solve_toeplitz

    return solve_toeplitz(a, b, check_finite=False)


def custom_deconvt(rsp_list, src, shift, spiking=1., length=None, normalize=0):
    """
    Time-domain deconvolution of ``src`` from one or several responses.

    Mirrors the time-domain deconvolution of the rf library (it reuses
    ``rf.deconvolve._acorrt`` and ``rf.deconvolve._xcorrt``), but solves the
    Toeplitz system with ``toeplitz_solver_scipy``.

    The system that is solved for every response is::

        RF = (STS + spiking*I)^-1 * STR

        STS ... autocorrelation of the source, normalised by its zero-lag
                value; it defines a symmetric Toeplitz matrix
        STR ... cross-correlation vector of response and source
        I   ... identity
        RF  ... receiver function (deconvolution result), ``length`` samples

    :param rsp_list: either a list/tuple of arrays containing the response
        functions or a single array
    :param src: array of source function
    :param shift: number of samples the source is shifted to the left to get
        the onset in the RF at the desired time (negative -> shift to the
        right); passed through to ``rf.deconvolve._xcorrt``
    :param spiking: constant added to the zero-lag term of the normalised
        autocorrelation, i.e. to the diagonal of the Toeplitz matrix
        (eg. 1.0, 0.1)
    :param length: number of data points in results; when None it is derived
        from ``rsp_list`` via the rf library
    :param normalize: index of the result whose maximum absolute amplitude is
        scaled to 1 (all results are scaled by the same factor). Set to None
        for no normalization.

    :return: list of arrays with the deconvolutions, or a single array if a
        single array was passed in
    """
    if length is None:
        length = rf.deconvolve.__get_length(rsp_list)

    # Normalised source autocorrelation with the spiking term on the zero lag.
    autocorr = rf.deconvolve._acorrt(src, length)
    autocorr = autocorr / autocorr[0]
    autocorr[0] += spiking

    # Accept a bare array by wrapping it; remember to unwrap on return.
    single_input = not isinstance(rsp_list, (list, tuple))
    responses = [rsp_list] if single_input else rsp_list

    results = []
    for response in responses:
        xcorr = rf.deconvolve._xcorrt(response, src, length, shift)
        assert len(xcorr) == len(autocorr)
        # Replaced solver here
        results.append(toeplitz_solver_scipy(autocorr, xcorr))

    if normalize is not None:
        # One common scale factor, taken from the selected reference result.
        scale = 1 / np.max(np.abs(results[normalize]))
        for result in results:
            result *= scale

    return results[-1] if single_input else results


In [None]:
def iter3c(stream):
    """
    Return an iterator over groups of traces in ``stream``.

    Thin wrapper around ``rf.util.IterMultipleComponents`` configured with
    key='onset' and number_components=(2, 3), i.e. presumably one group of
    2 or 3 component traces per onset.

    :param stream: stream holding the traces to group
    :return: the ``rf.util.IterMultipleComponents`` instance
    """
    grouping = {'key': 'onset', 'number_components': (2, 3)}
    return rf.util.IterMultipleComponents(stream, **grouping)

In [None]:
def __find_nearest(array, value):
    """
    Return the index of the element of sorted ``array`` nearest to ``value``.

    Based on http://stackoverflow.com/a/26026189

    Values below the first element map to index 0 and values above the last
    element map to the last index. When ``value`` is exactly halfway between
    two neighbours the higher index is returned.

    :param array: 1-D array sorted in ascending order
    :param value: value to look up
    :return: index into ``array``
    """
    idx = np.searchsorted(array, value, side='left')
    # Handle both ends before indexing: array[idx] does not exist when
    # idx == len(array), and there is no left neighbour when idx == 0.
    if idx == 0:
        return idx
    if idx == len(array):
        return idx - 1
    if np.abs(value - array[idx - 1]) < np.abs(value - array[idx]):
        return idx - 1
    return idx

def deconvolve_main(stream, func=None, source_components='LZ', response_components=None,
                    winsrc=(-10, 30, 5), **kwargs):
    """Copy of rf.deconvolve.deconvolve() function, modified to allow custom time-domain deconvolution method.

    The source trace is cut to the window ``winsrc`` around the onset,
    zero-padded and tapered, and then deconvolved from every response trace.

    Note that the response traces of ``stream`` are modified in place: their
    ``stats.onset`` is moved to the nearest sample and their ``data`` is
    replaced by the real part of the deconvolution result.

    :param stream: stream with the traces of one event; every trace needs
        ``stats.onset``
    :param func: deconvolution function called as
        ``func(rsp_data, src_data, shift, **kwargs)``; when None,
        ``rf.deconvolve.deconvt`` is used
    :param source_components: characters identifying the source component;
        matched against the last character of ``stats.channel``. Exactly one
        trace must match.
    :param response_components: characters identifying the response
        components, matched the same way; None selects all traces
    :param winsrc: (start, end, taper length) of the source window, in seconds
        relative to the onset
    :param kwargs: passed on to the deconvolution function. If 'normalize' is
        not given and the source is among the responses, it is set to the
        index of the source trace.
    :return: new stream of the same class as ``stream`` holding the response
        traces
    :raises ValueError: if not exactly one source trace is found, or if the
        number of response traces is not 1, 2 or 3
    """
    # identify source and response components
    src = [tr for tr in stream if tr.stats.channel[-1] in source_components]
    if len(src) != 1:
        msg = 'Invalid number of source components. %d not equal to one.'
        raise ValueError(msg % len(src))
    src = src[0]
    rsp = [tr for tr in stream if response_components is None or
           tr.stats.channel[-1] in response_components]
    if 'normalize' not in kwargs and src in rsp:
        # by default normalize all results by the deconvolved source trace
        kwargs['normalize'] = rsp.index(src)
    if not 0 < len(rsp) < 4:
        msg = 'Invalid number of response components. %d not between 0 and 4.'
        raise ValueError(msg % len(rsp))

    sr = src.stats.sampling_rate
    # shift onset to time of nearest data sample to circumvent complications
    # for data with low sampling rate and method='time'
    idx = __find_nearest(src.times(), src.stats.onset - src.stats.starttime)
    src.stats.onset = onset = src.stats.starttime + idx * src.stats.delta
    for tr in rsp:
        tr.stats.onset = onset

    # prepare source and response list
    if src in rsp:
        # trim/taper a copy so the response trace itself stays uncut
        src = src.copy()
    src.trim(onset + winsrc[0], onset + winsrc[1], pad=True, fill_value=0.)
    src.taper(max_percentage=None, max_length=winsrc[2])
    rsp_data = [tr.data for tr in rsp]
    # time shift equals the part of the source window before the onset
    tshift = -winsrc[0]
    shift = int(round(tshift * sr - len(src) // 2))
    if func is None:
        # Use rf library
        rf_data = rf.deconvolve.deconvt(rsp_data, src.data, shift, **kwargs)
    else:
        rf_data = func(rsp_data, src.data, shift, **kwargs)

    # write results back into the response traces (in place)
    for i, tr in enumerate(rsp):
        tr.data = rf_data[i].real

    return stream.__class__(rsp)


In [None]:
# Wrap the loaded waveforms in an rf.RFStream; both deconvolution runs below start from copies of it.
data_orig = rf.RFStream(src_data)

In [None]:
# Band-pass settings shared by both runs: passed as filter= to RFStream.rf() and as **kwargs to Stream.filter().
filter_args = {'type': 'bandpass', 'freqmin': 0.08, 'freqmax': 0.6, 'corners': 2, 'zerophase': True}

In [None]:
# Reference result: receiver functions from the rf library's own pipeline, computed on a copy
# so that data_orig stays available for the second run.
# NOTE(review): presumably this uses rf's default (time-domain) deconvolution - confirm.
rf_orig = data_orig.copy().rf(filter=filter_args, rotate='NE->RT', trim=(-10, 30))

In [None]:
# Check which container type RFStream.rf() returned.
type(rf_orig)

In [None]:
# Stack plot of the R-component reference receiver functions, time_window=(-10, 30).
fig_orig = rf_plot_utils.plot_rf_stack(rf_orig.select(component='R'), time_window=(-10, 30))

In [None]:
# Second run: same input data, but deconvolved with custom_deconvt (scipy Toeplitz solver).
# Work on a copy so that data_orig is not modified.
rf_scipy = data_orig.copy()
# method = 'P'
for stream3c in iter3c(rf_scipy):
    # Same filter settings and rotation as were passed to .rf() for the reference result.
    stream3c.filter(**filter_args)
    stream3c.rotate('NE->RT')
    # Z is the source component; the source window is (-10, 30) around the onset with a taper length of 5.
    response = deconvolve_main(stream3c, func=custom_deconvt, source_components='Z', winsrc=(-10, 30, 5))
    # NOTE(review): deconvolve_main writes the results into the trace objects in place; rf_scipy
    # presumably picks them up because the iterator yields the same trace objects - confirm.
    stream3c.traces = response.traces

In [None]:
# Check the container type of the custom result (rf_scipy is a copy of data_orig).
type(rf_scipy)

In [None]:
# Same stack plot as for the reference result, now for the R component of the scipy-solver run.
fig_scipy = rf_plot_utils.plot_rf_stack(rf_scipy.select(component='R'), time_window=(-10, 30))

In [None]:
# R-component traces of the reference result, used for the trace-by-trace comparison below.
rf_orig_R = rf_orig.select(component='R')

In [None]:
# R-component traces of the scipy-solver run. They are copied and then trimmed to (-10, 30)
# relative to the onset - the same trim values that were passed to .rf() for the reference result.
rf_scipy_R = rf_scipy.select(component='R').copy().trim2(-10, 30, reftime='onset')

In [None]:
# Number of reference R traces, i.e. the number of comparison plots produced below.
len(rf_orig_R)

In [None]:
# Overlay, trace by trace, the R-component receiver function from the rf library
# and the one computed with the scipy Toeplitz solver.
for tr_idx, tr_orig in enumerate(rf_orig_R):
    tr_scipy = rf_scipy_R[tr_idx]
    fig, ax = plt.subplots(figsize=(12, 4))
    for label, tr in (("rf library", tr_orig), ("custom (scipy solver)", tr_scipy)):
        # Time axis relative to the onset of the respective trace.
        onset_offset = tr.stats.onset - tr.stats.starttime
        ax.plot(tr.times() - onset_offset, tr.data, label=label)
    ax.set_ylim((-0.3, 0.7))
    ax.set_xlabel("Time relative to onset (s)")
    ax.grid()
    ax.set_title("n={}".format(tr_idx))
    # Without a legend the two curves cannot be told apart.
    ax.legend()
    plt.show()
    # Release the figure; one figure is created per trace.
    plt.close(fig)