In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import windows
from phw_lib import *
import obspy

In [2]:
import os

# Pin this notebook to a single GPU; the variable is set before importing
# TensorFlow so that TF only ever sees that device.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

import tensorflow as tf

# Pre-trained phasor-walkout classifiers (raw / bandpass / PAC variants, going
# by their file names). Only DL_model_raw is referenced in the visible cells.
_load_keras_model = tf.keras.models.load_model
DL_model_raw = _load_keras_model('../phasor_walkout_DL_classification_model.h5')
DL_model_bandpass = _load_keras_model('../phasor_walkout_DL_classification_model_bandpass.h5')
DL_model_PAC = _load_keras_model('../phasor_walkout_DL_classification_model_PAC.h5')

def get_DL_prob(xi, dt, f_target, model):
    """Score the phasor walkout of ``xi`` at ``f_target`` with a Keras classifier.

    The model input is a (1, N, 4) float array whose channels are, in order:
    real and imaginary parts of the phasors returned by
    ``phasor_walkout(xi, dt, f_target)``, then real and imaginary parts of
    their cumulative sum. The array is scaled by its largest absolute value
    before prediction.

    Parameters
    ----------
    xi : array-like
        Time series forwarded to ``phasor_walkout``.
    dt : float
        Sample interval forwarded to ``phasor_walkout``.
    f_target : float
        Frequency forwarded to ``phasor_walkout``.
    model : Keras model
        Classifier whose ``predict`` accepts a (1, N, 4) array.

    Returns
    -------
    The model output for the single sample, i.e. ``model.predict(x)[0]``.
    """
    phasors, _ = phasor_walkout(xi, dt, f_target)
    cumulative_sums = np.cumsum(phasors)
    input_length = len(cumulative_sums)

    # float64 container (np.zeros default) regardless of the phasor dtype.
    input_data = np.zeros((1, input_length, 4))
    input_data[0, :, 0] = phasors.real
    input_data[0, :, 1] = phasors.imag
    input_data[0, :, 2] = cumulative_sums.real
    input_data[0, :, 3] = cumulative_sums.imag

    # Scale into [-1, 1]. An all-zero walkout has max == 0; dividing would turn
    # every feature into NaN, so leave such input unscaled instead.
    scale = np.max(np.abs(input_data))
    if scale > 0:
        input_data /= scale

    return model.predict(input_data)[0]

2025-07-08 13:05:51.936280: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-07-08 13:05:52.634073: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22302 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:d1:00.0, compute capability: 8.6


In [3]:
# Events to compare. Each key is also the directory name under
# ./pwr_realworld_no_near_earthquakes/ that holds the After/Before records.
#   'time': event time as an ISO-8601 string (UTC 'Z' suffix on all but JP_M9.1)
#   'name': short "<magnitude>_<label>_<date>" tag; the part before the first
#           '_' is shown as the magnitude in the figures.
event_dict = {
    'M6.5_VANUATU_ISLANDS_20140101': {
        'time': '2014-01-01T16:03:29.770000Z',
        'name': 'M6.5_VANUATU_20140101',
    },
    'M6.6_NORTH_OF_ASCENSION_ISLAND_20170818': {
        'time': '2017-08-18T02:59:21.950000Z',
        'name': 'M6.6_ASCENSION_20170818',
    },
    'M7.2_OAXACA_MEXICO_20180216': {
        'time': '2018-02-16T23:39:39.700000Z',
        'name': 'M7.2_OAXACA_20180216',
    },
    'M7.5_SOUTHEASTERN_ALASKA_20130105': {
        'time': '2013-01-05T08:58:19.180000Z',
        'name': 'M7.5_ALASKA_20130105',
    },
    'M8.1_EAST_OF_KURIL_ISLANDS_20070113': {
        'time': '2007-01-13T04:23:23.250000Z',
        'name': 'M8.1_KURIL_20070113',
    },
    'M8.1_NEAR_COAST_OF_CHIAPAS_MEXICO_20170908': {
        'time': '2017-09-08T04:49:20.000000Z',
        'name': 'M8.1_CHIAPAS_20170908',
    },
    'JP_M9.1': {
        'time': '2011-03-11T05:46:24',
        'name': 'M9.1_TOHOKU_20110311',
    },
}

In [4]:
# Reference frequencies (Hz) of the free-oscillation modes named by the keys.
# Listed in mHz and scaled by 1e-3. The source Earth model is not stated here
# (presumably a 1-D reference model such as PREM; TODO confirm).
spherical_model_dict = {
    mode_name: freq_mhz*1e-3
    for mode_name, freq_mhz in (
        ('0S6', 1.03755),
        ('0S9', 1.57737),
        ('0S13', 2.11155),
        ('0S14', 2.22960),
        ('0S15', 2.34475),
        ('0S16', 2.45680),
    )
}

In [5]:
# Not populated anywhere in the visible code; presumably reserved for later
# cells (TODO: confirm, or remove if unused).
power_spectrum_dict = {}
res_dict = {}

# One figure row per event, in event_dict insertion order.
plot_keys = [event_key for event_key in event_dict]

# Length of each After/Before record analysed: three days, in seconds.
test_time_length_in_sec = 3 * 86400

def _panel_label(panel_idx):
    """Return the spreadsheet-style letter label for a 1-based panel index.

    1 -> 'a', 26 -> 'z', 27 -> 'aa', 28 -> 'ab', ...  This reproduces the
    (a)...(ab) tags of a 7-event x 4-column figure and keeps labelling panels
    when more events are added (a fixed lookup would silently stop).
    """
    label = ''
    remaining = panel_idx
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        label = chr(ord('a') + offset) + label
    return label


def _label_current_panel(panel_idx):
    """Draw the bold '(x)' tag just above the top-left corner of the current axes."""
    ax = plt.gca()
    ax.text(-0.03, 1.08, '({})'.format(_panel_label(panel_idx)),
            transform=ax.transAxes, size=20, weight='bold')


def _load_bhz_power_spectrum(event_name, net, sta, segment, time_length_in_sec):
    """Read one BHZ record and return its preprocessed samples and power spectrum.

    ``segment`` is the file-name part 'After' or 'Before'. The BHZ traces are
    trimmed to ``time_length_in_sec`` from the first trace's start, resampled
    to 0.5 Hz, demeaned and linearly detrended (same order for both segments).
    Only the first BHZ trace is analysed.

    Returns ``(data, dt, f, power)``:
      data  -- copy of the preprocessed samples
      dt    -- sample interval (s)
      f     -- np.fft.rfftfreq axis (Hz)
      power -- |rfft(hanning * data)|**2
    """
    path = './pwr_realworld_no_near_earthquakes/{}/{}.{}.{}.mseed'.format(event_name, net, sta, segment)
    st_bhz = obspy.read(path).select(channel='BHZ')
    if len(st_bhz) == 0:
        raise ValueError('No BHZ trace found in {}'.format(path))

    start = st_bhz[0].stats.starttime
    st_bhz.trim(start, start + time_length_in_sec)
    st_bhz.resample(0.5, window='hann')
    st_bhz.detrend('demean')
    st_bhz.detrend('linear')

    trace = st_bhz[0]
    npts = trace.stats.npts
    dt = trace.stats.delta
    data = trace.data.copy()
    f = np.fft.rfftfreq(npts, dt)
    power = np.abs(np.fft.rfft(data * np.hanning(npts)))**2
    return data, dt, f, power


def _nearest_peak_index(f, power, target_freq):
    """Return the index of the local maximum of ``power`` whose frequency is closest to ``target_freq``."""
    # NOTE(review): find_peaks is not imported explicitly; it presumably comes
    # from `from phw_lib import *` (scipy.signal.find_peaks?) - confirm.
    peaks, _ = find_peaks(power, distance=1)
    return peaks[np.argmin(np.abs(f[peaks] - target_freq))]


def _walkout_panel(data, dt, freq, title):
    """Plot the phasor walkout of ``data`` at ``freq`` on the current axes and annotate its metrics.

    Annotates R^2, Schuster significance (and its second return value), CIPSR
    (linearity_measure_by_sum_ratio) and the first output of DL_model_raw.
    Returns ``(R2, schuster_significance, ss_log, sum_ratio, DL_prob)``.
    """
    visualize_walkout_for_subplot_sum_with_period(data, dt, freq)
    R2 = calculate_walkout_R2(data, dt, freq)
    sum_ratio = linearity_measure_by_sum_ratio(data, dt, freq)
    schuster_significance, ss_log = schuster_test_for_phasor_walkout(data, dt, freq)
    DL_prob = get_DL_prob(data, dt, freq, DL_model_raw)[0]
    plt.text(0.03, 0.10, 'R$^2$: {:.3f}\nSS: {:.3f}\nSS$_m$$_a$$_g$: {:.3f}\nCIPSR: {:.3f}\nDL_prob: {:.3f}'.format(R2, schuster_significance, ss_log, sum_ratio, DL_prob), fontsize=14, transform=plt.gca().transAxes, weight='bold', zorder=10, color='k')
    plt.title(title, fontsize=18)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    return R2, schuster_significance, ss_log, sum_ratio, DL_prob


# One figure per mode: each row is an event; columns are
#   1) After/Before power spectra around the mode,
#   2) walkout of the Before record at its nearest spectral peak,
#   3) walkout of the After record at its nearest spectral peak,
#   4) walkout of the After record at the reference mode frequency.
for mode_key in spherical_model_dict.keys():
    print('On mode: {}'.format(mode_key))
    check_spherical_mode_name = mode_key
    check_spherical_mode_frequency = spherical_model_dict[mode_key]
    # +/-0.04 mHz band around the reference, widened by 0.2 mHz on each side for display.
    min_freq_base = check_spherical_mode_frequency - 0.04*1e-3
    max_freq_base = check_spherical_mode_frequency + 0.04*1e-3
    delta_freq_expand = 0.2*1e-3 # for visualization
    min_freq = min_freq_base - delta_freq_expand
    max_freq = max_freq_base + delta_freq_expand

    total_plot_num = len(plot_keys)
    plt.figure(figsize=(24, 4*total_plot_num))

    net = 'G'
    sta = 'CAN'
    for row_idx, key in enumerate(plot_keys):
        event_name = key
        first_panel = row_idx*4 + 1
        print('On event: {} Time Length: {}'.format(key, test_time_length_in_sec))

        after_data, after_dt, after_f, after_power = _load_bhz_power_spectrum(
            event_name, net, sta, 'After', test_time_length_in_sec)
        after_peak = _nearest_peak_index(after_f, after_power, check_spherical_mode_frequency)
        after_freq = after_f[after_peak]

        before_data, before_dt, before_f, before_power = _load_bhz_power_spectrum(
            event_name, net, sta, 'Before', test_time_length_in_sec)
        before_peak = _nearest_peak_index(before_f, before_power, check_spherical_mode_frequency)
        before_freq = before_f[before_peak]

        # ---- Column 1: power spectra ----
        plt.subplot(total_plot_num, 4, first_panel)
        _label_current_panel(first_panel)

        plt.plot(after_f, after_power, color='orange', label='After')
        plt.scatter(after_f[after_peak], after_power[after_peak], c='r', s=100)

        # Mode marker; its name sits at 60% of the strongest After bin in the displayed band.
        min_freq_index = np.argmin(np.abs(after_f - min_freq))
        max_freq_index = np.argmin(np.abs(after_f - max_freq))
        plt.axvline(check_spherical_mode_frequency, color='k', linestyle='--')
        plt.text(check_spherical_mode_frequency, np.max(after_power[min_freq_index:max_freq_index])*0.6, check_spherical_mode_name, fontsize=16, weight='bold', ha='center', va='bottom')

        plt.plot(before_f, before_power, color='royalblue', label='Before')
        plt.legend(loc='upper right', fontsize=18)
        plt.scatter(before_f[before_peak], before_power[before_peak], c='r', s=100)

        # y range: 1/100 of the lower selected peak up to 50x the higher one.
        peak_power_lo = min(after_power[after_peak], before_power[before_peak])
        peak_power_hi = max(after_power[after_peak], before_power[before_peak])
        plt.xlim([min_freq, max_freq])
        plt.yscale('log')
        plt.ylim([peak_power_lo*1e-2, peak_power_hi*1e1*5])

        plt.xlabel('Frequency (Hz)', fontsize=18)
        plt.ylabel('Power', fontsize=18)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        name_parts = event_dict[key]['name'].split('_')
        mag = name_parts[0]
        ev_name = '_'.join(name_parts[1:])
        plt.title('{}'.format(ev_name), fontsize=18)
        plt.text(0.05, 0.8, '{}'.format(mag), fontsize=28, transform=plt.gca().transAxes, weight='bold', zorder=15, color='k')

        # ---- Columns 2-4: phasor walkouts (each with its own record's dt) ----
        plt.subplot(total_plot_num, 4, first_panel + 1)
        _label_current_panel(first_panel + 1)
        _walkout_panel(before_data, before_dt, before_freq,
                       '{:.4f} mHz Walkout (Before)'.format(before_freq*1e3))

        plt.subplot(total_plot_num, 4, first_panel + 2)
        _label_current_panel(first_panel + 2)
        _walkout_panel(after_data, after_dt, after_freq,
                       '{:.4f} mHz Walkout (After)'.format(after_freq*1e3))

        plt.subplot(total_plot_num, 4, first_panel + 3)
        _label_current_panel(first_panel + 3)
        _walkout_panel(after_data, after_dt, check_spherical_mode_frequency,
                       '{:.4f} mHz Walkout (Reference)'.format(check_spherical_mode_frequency*1e3))

    plt.tight_layout()

    plt.savefig('./pwr_realworld_no_near_earthquakes/final_comparison_{}_power_spectrum_{}.png'.format(test_time_length_in_sec, check_spherical_mode_name), dpi=500)
    plt.close()


On mode: 0S6
On event: M6.5_VANUATU_ISLANDS_20140101 Time Length: 259200


2025-07-08 13:06:16.615225: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2025-07-08 13:06:18.298640: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8401


On event: M6.6_NORTH_OF_ASCENSION_ISLAND_20170818 Time Length: 259200
On event: M7.2_OAXACA_MEXICO_20180216 Time Length: 259200
On event: M7.5_SOUTHEASTERN_ALASKA_20130105 Time Length: 259200
On event: M8.1_EAST_OF_KURIL_ISLANDS_20070113 Time Length: 259200
On event: M8.1_NEAR_COAST_OF_CHIAPAS_MEXICO_20170908 Time Length: 259200
On event: JP_M9.1 Time Length: 259200
On mode: 0S9
On event: M6.5_VANUATU_ISLANDS_20140101 Time Length: 259200
On event: M6.6_NORTH_OF_ASCENSION_ISLAND_20170818 Time Length: 259200
On event: M7.2_OAXACA_MEXICO_20180216 Time Length: 259200
On event: M7.5_SOUTHEASTERN_ALASKA_20130105 Time Length: 259200
On event: M8.1_EAST_OF_KURIL_ISLANDS_20070113 Time Length: 259200
On event: M8.1_NEAR_COAST_OF_CHIAPAS_MEXICO_20170908 Time Length: 259200
On event: JP_M9.1 Time Length: 259200
On mode: 0S13
On event: M6.5_VANUATU_ISLANDS_20140101 Time Length: 259200
On event: M6.6_NORTH_OF_ASCENSION_ISLAND_20170818 Time Length: 259200
On event: M7.2_OAXACA_MEXICO_20180216 Time Le