In [None]:
import os
import sys

In [None]:
import datetime
from collections import defaultdict

import dateutil
import matplotlib.dates
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.signal
import scipy.stats
from dateutil import rrule

In [None]:
# Make the repository root (three directories up) importable so the `seismic` package resolves.
package_root = os.path.abspath(os.path.join('..', '..', '..'))
if not any(entry == package_root for entry in sys.path):
    sys.path.append(package_root)
from seismic.ASDFdatabase import FederatedASDFDataSet

In [None]:
import obspy
from analytic_plot_utils import distance
from netCDF4 import Dataset as NCDataset

In [None]:
# Imports for plotting
from textwrap import wrap
from scipy import signal

In [None]:
# Federated view over the ASDF waveform archive; used below only for station metadata (get_stations).
# NOTE(review): site-specific absolute path — consider making it configurable.
ds = FederatedASDFDataSet.FederatedASDFDataSet("/g/data/ha3/Passive/SHARED_DATA/Index/asdf_files.txt")

In [None]:
# Cross-correlation result file. The two station codes are parsed from its name (NET.STA.NET.STA.nc).
SRC_FILE = "/g/data/ha3/am7399/shared/xcorr/AU/ARMA_CMSA/AU.ARMA.AU.CMSA.nc"

In [None]:
# TIME_WINDOW: half-width [s] of the (+/-) lag window extracted from each CCF.
# SNR_THRESHOLD: intervals whose CCF SNR exceeds this contribute to the reference CCF.
TIME_WINDOW, SNR_THRESHOLD = 300, 10

In [None]:
# Read xcorr data. The `with` block guarantees the file is closed even if a variable read fails.
with NCDataset(SRC_FILE, 'r') as xcdata:
    print(xcdata)

    xc_start_times = xcdata.variables['IntervalStartTimes'][:]  # sTimes
    xc_end_times = xcdata.variables['IntervalEndTimes'][:]  # eTimes
    xc_lag = xcdata.variables['lag'][:]  # lag
    xc_xcorr = xcdata.variables['xcorr'][:, :]  # xcorr
    xc_nStackedWindows = xcdata.variables['NumStackedWindows'][:]  # nStackedWindows
# Drop the reference to the (now closed) dataset so it cannot be used by accident.
xcdata = None

# Overall time span covered by the file: start of first interval to end of last interval.
start_utc_time = obspy.UTCDateTime(xc_start_times[0])
end_utc_time = obspy.UTCDateTime(xc_end_times[-1])
print((start_utc_time, end_utc_time))

In [None]:
# String forms of the covered span; these are what the station lookups below are queried with.
start_time, end_time = (str(t) for t in (start_utc_time, end_utc_time))
print((start_time, end_time))

In [None]:
# Get station codes from file name
def stationCodes(filename):
    """Split an xcorr file name of the form ``NET1.STA1.NET2.STA2.<ext>`` into station codes.

    :param filename: path to the xcorr file; only its final component is inspected.
    :returns: ``(code1, code2)``, each ``'NET.STA'`` built from dot-separated
        fields 0-1 and 2-3 of the file name respectively.
    """
    fields = os.path.basename(filename).split('.')
    return '.'.join(fields[:2]), '.'.join(fields[2:4])

In [None]:
def stationCoords(federated_ds, code, datetime):
    """Return the coordinates stored for a station's ``?HZ`` channel.

    :param federated_ds: dataset object providing ``get_stations(start, end, network=, station=)``.
    :param code: station code in ``'NET.STA'`` form.
    :param datetime: start of the query window; anything ``obspy.UTCDateTime`` accepts.
        (This parameter shadows the ``datetime`` module inside this function only.)
    :returns: fields 4-5 of the matching station record; presumably (lon, lat) —
        NOTE(review): confirm the record layout against ``get_stations``.
    :raises ValueError: unless exactly one ``?HZ`` channel record is found. An explicit
        exception is used instead of ``assert`` so the check survives ``python -O``
        and reports what was found.
    """
    net, sta = code.split('.')
    # Query a window of 3600 s starting at ``datetime``.
    sta_records = federated_ds.get_stations(datetime, obspy.UTCDateTime(datetime) + 3600,
                                            network=net, station=sta)
    # Field 3 is presumably the channel code; characters 1-2 == 'HZ' selects e.g. BHZ/HHZ.
    z_records = [r for r in sta_records if r[3][1:3] == 'HZ']
    if len(z_records) != 1:
        raise ValueError('Expected exactly one ?HZ channel for {} starting {}, found {}'
                         .format(code, datetime, len(z_records)))
    return z_records[0][4:6]

In [None]:
def stationDistance(federated_ds, code1, code2, datetime):
    """Distance between the ``?HZ`` channel coordinates of two stations.

    Coordinates come from ``stationCoords`` (``code1`` looked up first, then ``code2``)
    and are combined by ``analytic_plot_utils.distance``; units are whatever that
    helper returns.
    """
    endpoints = [stationCoords(federated_ds, station, datetime) for station in (code1, code2)]
    return distance(*endpoints)

In [None]:
# Origin and destination station codes ('NET.STA') parsed from the xcorr file name.
origin_code, dest_code = stationCodes(SRC_FILE)

In [None]:
# Sanity check: coordinates of the origin station's ?HZ channel at the start of the data span.
stationCoords(ds, origin_code, start_time)

In [None]:
# Sanity check: coordinates of the destination station's ?HZ channel at the start of the data span.
stationCoords(ds, dest_code, start_time)

In [None]:
# Inter-station distance, shown in the summary figure title.
# NOTE(review): the title labels it km; units are whatever analytic_plot_utils.distance returns — confirm.
dist = stationDistance(ds, origin_code, dest_code, start_time)
print(dist)

In [None]:
# Extract primary data
# Keep lags with |lag| <= TIME_WINDOW. Lags are rounded to 0.01 s first so float noise in
# the stored lag axis does not push the +/-TIME_WINDOW end points out of the window.
# Assumes xc_lag is monotonic, so the selected samples form one contiguous run.
in_window = np.fabs(np.round(xc_lag, decimals=2)) <= TIME_WINDOW
window_indices = np.flatnonzero(in_window)
if window_indices.size == 0:
    raise ValueError('No lag samples within +/-{} s'.format(TIME_WINDOW))
# First and last lag index of the window. The slices below are inclusive of BOTH ends,
# so the +TIME_WINDOW sample is kept (an exclusive stop would silently drop it).
lagIndices = np.array([window_indices[0], window_indices[-1]])
sTimes = xc_start_times
lag = xc_lag[lagIndices[0]:lagIndices[1] + 1]
ccf = xc_xcorr[:, lagIndices[0]:lagIndices[1] + 1]
nsw = xc_nStackedWindows

In [None]:
# Compute derived quantities used by multiple axes
# A CCF row that is entirely zero is treated as "no data" and masked out completely.
zero_row_mask = np.ma.filled(np.all(ccf == 0, axis=1), False)
valid_mask = np.ones(ccf.shape, dtype=bool)
valid_mask[zero_row_mask, :] = False
ccfMasked = np.ma.masked_array(ccf, mask=~valid_mask)

# Per-row SNR: peak value over standard deviation of the (masked) CCF row.
snr = np.nanmax(ccfMasked, axis=1) / np.nanstd(ccfMasked, axis=1)

# Resolve masked comparison results to False *before* boolean indexing: numpy indexes
# with the underlying data of a masked boolean array, which is meaningless at masked slots.
high_snr_rows = np.ma.filled(snr > SNR_THRESHOLD, False)
if np.any(high_snr_rows):
    # Reference CCF (RCF): mean of the high-SNR rows.
    rcf = np.nanmean(ccfMasked[high_snr_rows, :], axis=0)
else:
    rcf = None

In [None]:
def debugLabelAxes(ax, label):
    """Stamp ``label`` in large text at the top-centre of ``ax`` (axes-fraction coordinates).

    Used to identify panels while laying out a multi-axes figure.
    """
    text_style = dict(horizontalalignment='center', verticalalignment='top', fontsize=20)
    ax.text(0.5, 0.95, label, transform=ax.transAxes, **text_style)

In [None]:
def timestampToPlottableDatetime(data):
    """Convert POSIX timestamps (seconds since epoch, UTC) to a ``datetime64[s]`` array.

    matplotlib can plot the result directly as dates. ``data.transform`` (a pandas-only
    method) is not used: it does not exist on numpy or masked arrays.

    :param data: array-like of timestamps; any shape (masks, if present, are ignored).
    :returns: ``numpy.ndarray`` of dtype ``datetime64[s]`` with the same shape as ``data``;
        sub-second parts are dropped.
    """
    values = np.asarray(data, dtype=float)
    # Naive UTC datetimes (tz stripped) so numpy does not warn about timezone-aware input.
    as_datetimes = [
        datetime.datetime.fromtimestamp(float(v), tz=datetime.timezone.utc).replace(tzinfo=None)
        for v in values.ravel()
    ]
    # Build at microsecond resolution first (exact for datetime objects), then truncate to seconds.
    stamps = np.array(as_datetimes, dtype='datetime64[us]').astype('datetime64[s]')
    return stamps.reshape(values.shape)

In [None]:
def plotXcorrTimeseries(ax, x_lag, y_times, xcorr_data, use_formatter=False):
    """Draw per-interval CCFs as a lag-vs-time image with a small inset colorbar.

    :param ax: matplotlib Axes to draw into.
    :param x_lag: 1-D lag values [s] for the x axis.
    :param y_times: 1-D POSIX timestamps (UTC seconds), one per row of ``xcorr_data``.
    :param xcorr_data: 2-D array of shape ``(len(y_times), len(x_lag))``; colour scale
        is clamped to [0, 1].
    :param use_formatter: if True, label the time axis with a Sunday-weekly date locator;
        otherwise (default) put a date label on every 7th row.
    """
    # Rows are placed at the caller-supplied ``y_times`` (not a notebook global), converted
    # to naive-UTC datetime64 so matplotlib treats the y axis as dates.
    np_times = np.array(
        [datetime.datetime.fromtimestamp(float(t), tz=datetime.timezone.utc).replace(tzinfo=None)
         for t in y_times],
        dtype='datetime64[us]').astype('datetime64[s]')
    gx, gy = np.meshgrid(x_lag, np_times)
    im = ax.pcolormesh(gx, gy, xcorr_data, cmap='RdYlBu_r', vmin=0, vmax=1, rasterized=True)

    if use_formatter:
        date_formatter = matplotlib.dates.DateFormatter("%Y-%m-%d")
        date_locator = matplotlib.dates.WeekdayLocator(byweekday=rrule.SU)
        ax.yaxis.set_major_formatter(date_formatter)
        ax.yaxis.set_major_locator(date_locator)
    else:
        # Label every 7th row with its calendar date.
        labels = np.datetime_as_string(np_times, unit='D')
        ax.set_yticks(np_times[::7])
        ax.set_yticklabels(labels[::7])

    ax.set_xlabel('Lag [s]')
    ax.set_ylabel('Days')

    # Inset colorbar near the top-left corner of ``ax``, created on ax's own figure rather
    # than on whichever figure pyplot currently considers "current".
    fig = ax.figure
    ax_pos = ax.get_position()
    cax = fig.add_axes([ax_pos.x0 + 0.025, ax_pos.y1 - 0.1, 0.015, 0.08])
    fig.colorbar(im, cax=cax, orientation='vertical', ticks=[0, 1])

In [None]:
# print(ccf.shape)
# print(type(ccf))
# print(sTimes.shape)
# plt.figure(figsize=(16,9))
# plt.subplot(311)
# plt.plot(sTimes, 'x')
# ccf_mean = np.mean(ccf, axis=1)
# print(ccf_mean.shape)
# plt.subplot(312)
# plt.plot(ccf_mean, '+')
# plt.subplot(313)
# plt.plot(sTimes, ccf_mean, 'v')

In [None]:
# print(xc_xcorr.shape)
# print(type(xc_xcorr))
# print(sTimes.shape)
# plt.figure(figsize=(16,9))
# plt.subplot(311)
# plt.plot(sTimes, 'x')
# xc_xcorr_mean = np.mean(xc_xcorr, axis=1)
# print(xc_xcorr_mean.shape)
# plt.subplot(312)
# plt.plot(xc_xcorr_mean, '+')
# plt.subplot(313)
# plt.plot(sTimes, xc_xcorr_mean, 'v')

In [None]:
def plotRCF(ax, rcf, x_lag=None, snr_threshold=None):
    """Plot the reference CCF (RCF) and mark the lag of its peak.

    :param ax: matplotlib Axes to draw into; its x tick labels are always cleared.
    :param rcf: 1-D reference CCF, or None when no interval passed the SNR cut
        (a notice is drawn instead).
    :param x_lag: lag values [s] matching ``rcf``; defaults to the notebook-level ``lag``
        array (the previous body referenced an undefined ``x_lag`` and raised NameError).
    :param snr_threshold: SNR cut quoted in the legend; defaults to ``SNR_THRESHOLD``.
    """
    if rcf is not None:
        # Fallbacks are resolved lazily so explicit arguments need no notebook globals.
        if x_lag is None:
            x_lag = lag
        if snr_threshold is None:
            snr_threshold = SNR_THRESHOLD
        peak_lag = x_lag[np.argmax(rcf)]
        ax.axvline(peak_lag, c='#c66da9', lw=2,
                   label='{:5.2f} s'.format(peak_lag))
        # Plain (non-raw) string so "\n" is a real line break in the legend.
        ax.plot(x_lag, rcf, c='#42b3f4',
                label="Reference CCF \nBased on Subset \nwith SNR > {}".format(snr_threshold))
        ax.legend()
    else:
        ax.text(0.5, 0.5, 'REFERENCE CCF:\nINSUFFICIENT SNR', horizontalalignment='center',
                verticalalignment='center', transform=ax.transAxes, fontsize=20)

    ax.set_xticklabels([])

In [None]:
# Quick diagnostic: per-row SNR as computed in the derived-quantities cell (reused, not
# recomputed, so this plot cannot drift from the values used for the reference CCF).
plt.plot(snr)
plt.xlabel('CCF row (interval) index')
plt.ylabel('SNR')
plt.show()

In [None]:
# Summary figure: CCF image (ax1), reference CCF (ax2), stacked-window counts (ax3),
# SNR histogram (ax4), Pearson coefficient vs RCF (ax5) and estimated timeshift (ax6).
SHOW_AXES_DEBUG_LABELS = True  # stamp 'ax1'..'ax6' on the panels while tuning the layout
CC_FLOOR_FRACTION = 0.85       # rows with Pearson coeff below this * CC_ave get timeshift 0


def _hideOuterTickLabels(tick_labels):
    """Hide the first and last tick labels (if any) so adjacent panels' labels do not collide."""
    if len(tick_labels) > 0:
        tick_labels[0].set_visible(False)
        tick_labels[-1].set_visible(False)


def _pearsonPerRow(rows, ref):
    """Pearson r of each CCF row against ``ref``.

    Rows containing any masked value score 0; otherwise rows score NaN when ``ref`` is None.
    """
    coeffs = []
    for row in rows:
        if np.ma.is_masked(row):
            coeffs.append(0)
        elif ref is None:
            coeffs.append(np.nan)
        else:
            r_value, _ = scipy.stats.pearsonr(ref, row)
            coeffs.append(r_value)
    return np.array(coeffs)


def _timeshiftPerRow(rows, ref, coeffs, coeff_floor, lags):
    """Lag of the cross-correlation peak between ``ref`` and each CCF row.

    Masked rows, and rows whose Pearson coefficient is below ``coeff_floor``, get 0;
    unmasked rows get NaN when ``ref`` is None.
    """
    shifts = []
    for row, coeff in zip(rows, coeffs):
        if np.ma.is_masked(row):
            shifts.append(0)
        elif ref is None:
            shifts.append(np.nan)
        elif coeff < coeff_floor:
            shifts.append(0)
        else:
            # Take argmax of the raw correlation. Rescaling by max(c) does not move the peak
            # when max(c) > 0, but turns it into the minimum when max(c) < 0 (and NaNs at 0).
            xc = scipy.signal.correlate(ref, row, mode='same')
            shifts.append(lags[np.argmax(xc)])
    return np.array(shifts)


# Row times as naive-UTC datetime64, matching the CCF image's y axis, so the side panels
# can be put on the same time scale as ax1.
plot_times = np.array(
    [datetime.datetime.fromtimestamp(float(t), tz=datetime.timezone.utc).replace(tzinfo=None)
     for t in sTimes],
    dtype='datetime64[us]').astype('datetime64[s]')

fig = plt.figure(figsize=(18, 32))
fig.suptitle("Station: {}, Dist. to {}: {:3.2f} km".format(origin_code, dest_code, dist), fontsize=16, y=1)

labelPad = 0.05
ax1 = fig.add_axes([0.1, 0.075, 0.5, 0.725])  # CCF image
ax2 = fig.add_axes([0.1, 0.8, 0.5, 0.175])  # reference CCF (accumulation of daily CCFs)
ax3 = fig.add_axes([0.6, 0.075, 0.1, 0.725])  # number of stacked windows
ax4 = fig.add_axes([0.6 + labelPad, 0.8 + labelPad, 0.345, 0.175 - labelPad])  # histogram
ax5 = fig.add_axes([0.7, 0.075, 0.1, 0.725])  # Pearson coeff
ax6 = fig.add_axes([0.8, 0.075, 0.195, 0.725])  # estimate timeshifts
if SHOW_AXES_DEBUG_LABELS:
    for panel_name, panel in zip(('ax1', 'ax2', 'ax3', 'ax4', 'ax5', 'ax6'),
                                 (ax1, ax2, ax3, ax4, ax5, ax6)):
        debugLabelAxes(panel, panel_name)

# Plot CCF image =======================
plotXcorrTimeseries(ax1, lag, sTimes, ccf)
# Final time limits of the image; applied to the side panels so their rows line up with ax1.
time_limits = ax1.get_ylim()

# Plot CCF-template =====================
plotRCF(ax2, rcf)

# Plot number of stacked windows ==============
ax3.plot(nsw, plot_times, c='#529664')
ax3.set_ylim(time_limits)
ax3.tick_params(axis='y', labelleft=False)
ax3.set_xlabel('\n'.join(wrap('# of Hourly Stacked Windows', 12)))
_hideOuterTickLabels(ax3.get_xticklabels())

# Plot histogram ==============================
# np.ma.compressed drops masked SNR values (and also accepts a plain ndarray).
ax4.hist(np.ma.compressed(snr), fc='#42b3f4', ec='none', bins=10)
ax4.set_xlabel('SNR: Daily CCFs [-%d, %d]s' % (TIME_WINDOW, TIME_WINDOW))
ax4.set_ylabel('Frequency')
_hideOuterTickLabels(ax4.get_xticklabels())

# plot cc ===================
cc = _pearsonPerRow(ccfMasked, rcf)
# CC_ave: mean Pearson coefficient, ignoring rows scored 0 (masked rows).
ccav = np.mean(np.ma.masked_array(cc, mask=cc == 0))

ax5.plot(cc, plot_times, c='#d37f26')
ax5.set_ylim(time_limits)
ax5.tick_params(axis='y', labelleft=False)
ax5.set_xticks([0, 1])
ax5.set_xlabel('\n'.join(wrap('Pearson Coeff. (RCF * CCF)', 15)))
ax5.text(0.5, 0.95, '$CC_{ave}$=%3.3f' % ccav, horizontalalignment='center',
         verticalalignment='center', transform=ax5.transAxes)

# plot Timeshift =====================
corr = _timeshiftPerRow(ccfMasked, rcf, cc, CC_FLOOR_FRACTION * ccav, lag)
ax6.plot(corr, plot_times, c='#f22e62', lw=1.5)
ax6.set_ylim(time_limits)
ax6.tick_params(axis='y', labelleft=False)
_hideOuterTickLabels(ax6.get_xticklabels())
ax6.set_xlabel('\n'.join(wrap('Estimated Timeshift [s]: RCF * CCF', 15)))

plt.show()
plt.close(fig)