# Test iterative deconvolution on Bilby (network 7W) station BL05

In [None]:
import os
import logging

import numpy as np
import matplotlib.pyplot as plt

import obspy
import rf

In [None]:
from seismic.receiver_fn.rf_plot_utils import plot_rf_stack
from seismic.receiver_fn.rf_deconvolution import iter_deconv_pulsetrain

In [None]:
# Event waveforms recorded at Bilby station 7W.BL05, read from an HDF5 file using
# obspy's 'h5' format. The records are broadband data sampled at 50 Hz; the
# processing cell further down resamples them to a lower rate before iterative
# deconvolution is applied.
data_dir = "../DATA"
src_trace_file = data_dir + "/7W.BL05_event_waveforms_for_rf_filtered.h5"
src_data = obspy.read(src_trace_file, format="h5")

In [None]:
# Logger handed to the deconvolution routine; only ERROR-level (and higher)
# records emitted on it are let through.
log = logging.getLogger(name=__name__)
log.setLevel(level=logging.ERROR)

In [None]:
# Run iterative deconvolution on each group of traces sharing the same onset (i.e. the
# same event), plot observed vs predicted radial component and the estimated RF for
# each event, then plot all RFs as a stack sorted by back azimuth.
max_iterations = 200
time_window = (-10, 30)  # seconds relative to onset
i = 0  # running event counter, used to name the per-event figure files
all_traces = []
freq_cutoff = (0.25, 1.0)  # bandpass corner frequencies (Hz)
outfolder = 'test_iterdeconv_freq{:.2f}-{:.2f}'.format(*freq_cutoff)

# exist_ok avoids the separate exists-check before creating the folder.
os.makedirs(outfolder, exist_ok=True)

for stream3c in rf.util.IterMultipleComponents(src_data.copy(), key='onset', number_components=(2, 3)):
    # Zero-phase bandpass, then resample to 5 Hz by interpolation. The upper corner
    # (freq_cutoff[1]) should stay below the 2.5 Hz Nyquist of the resampled data.
    stream3c.filter('bandpass', freqmin=freq_cutoff[0], freqmax=freq_cutoff[1], corners=2,
                    zerophase=True).interpolate(5.0)
    rf_stream = rf.RFStream(stream3c)
    rf_stream.rotate('NE->RT')
    rf_stream.trim2(*time_window, reftime='onset')
    rf_stream.detrend('linear')
    rf_stream.taper(0.2, max_length=0.5)
    source = rf_stream.select(component='Z')[0]
    response = rf_stream.select(component='R')[0]
    F_s = source.stats.sampling_rate
    # Offset (seconds) of the onset from the start of the trimmed trace.
    time_shift = source.stats.onset - source.stats.starttime

    rf_trace, pulses, expected_response, predicted_response, fit = iter_deconv_pulsetrain(
        response, source, F_s, time_shift, max_iterations, gwidth=2.5, log=log)

    # Normalize RF to unit L2 norm. Skip when the RF is identically zero: dividing by a
    # zero norm would fill the trace with NaNs and corrupt the final stack plot.
    sum_sq = np.sum(np.square(rf_trace))
    if sum_sq > 0:
        rf_trace /= np.sqrt(sum_sq)
    tr = response.copy()
    tr.data = rf_trace.copy()
    all_traces.append(tr)

    # Generate plots; time axis is relative to the onset.
    event_id = source.stats.event_id
    times = source.times() - time_shift
    plt.figure(figsize=(12, 8))
    plt.subplot(211)
    plt.plot(times, expected_response, alpha=0.8)
    plt.plot(times, predicted_response, alpha=0.8)
    plt.xlabel("Time (s)")
    plt.ylabel("Radial amplitude")
    plt.text(0.02, 0.07, "Input filter band: ({:.2f}, {:.2f}) Hz".format(*freq_cutoff),
             fontsize=8, transform=plt.gca().transAxes, color="#404040")
    plt.text(0.02, 0.02, "Prediction match to observation: {:.2f}%".format(fit), fontsize=8,
             transform=plt.gca().transAxes, color="#404040")
    # Colour must be passed by keyword: the first positional argument of grid() is
    # 'visible', so a positional colour string would only switch the grid on.
    plt.grid(color="#80808080", linestyle=':')
    plt.legend(['Expected R-component', 'Predicted by RF'])
    plt.title("Event {} observed vs predicted Radial component".format(event_id))

    plt.subplot(212)
    plt.plot(times, rf_trace)
    plt.xlabel("Time (s)")
    plt.ylabel("RF amplitude (arb. units)")
    plt.grid(color="#80808080", linestyle=':')
    plt.legend(['Computed RF'])
    plt.title("Estimated Receiver Function", y=0.9)

    plt.savefig(os.path.join(outfolder, '{:03d}.png'.format(i)), dpi=300)
    plt.show()
    plt.close()

    i += 1
# end for

all_rf_stream = rf.RFStream(all_traces).sort(keys=['back_azimuth'])
stack_file = os.path.join(outfolder, 'rf_stack.png')
plot_rf_stack(all_rf_stream, time_window=time_window, trace_height=0.12, save_file=stack_file, dpi=300)
# plot_rf_stack(all_rf_stream, time_window=time_window, trace_height=0.12)
plt.show()
plt.close()