In [None]:
import numpy as np
import pyximport
# The pyximport import hook is installed before `import f12` below.
# NOTE(review): f12 is presumably a Cython (.pyx) module that gets compiled on
# import through this hook -- confirm against the project layout.
pyximport.install()
import f12
import sys
import time
from numba import njit, prange
from scipy.signal import hilbert
#--------------------------------------------------------------------------
# Float epsilon (difference between 1.0 and the next representable float).
# It is passed to cal_sta, which adds it to the long-window mean before dividing.
machine_epsilon = sys.float_info.epsilon

In [None]:
@njit(parallel=True, fastmath=True, cache=True)
def x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, ttu):
    """Build an (L1, L2) image by delaying, multiplying and summing the rows of ``ttu``.

    For every grid node (i, j), ``UU2[k][i, j]`` of each of the ``n_r`` traces is
    converted to a sample index ``floor(value / dt) + 1``. Trace ``k`` is then
    written into a zero-padded buffer delayed by ``max(index) - index[k]``
    samples, all delayed traces are multiplied sample by sample, and the sum of
    that product becomes the image value at (i, j).

    Parameters
    ----------
    L1, L2 : int
        Number of rows / columns of the returned image.
    n_r : int
        Number of traces combined at each node.
    UU2 : indexable as ``UU2[k][i, j]``
        One 2-D table per trace. Presumably travel times in the same unit as
        ``dt`` -- TODO confirm the layout against the caller.
    dt : float
        Divisor that turns a ``UU2`` value into a sample index.
    w_L2 : int
        Number of samples copied from each trace; it has to match
        ``ttu.shape[1]`` for the slice assignment to be valid.
    ttu : 2-D array
        Traces, one per row.

    Returns
    -------
    2-D float64 array of shape (L1, L2).
    """
    image = np.zeros((L1, L2))
    for i in range(L1):
        for j in prange(L2):
            # Sample index of every trace at this node.
            sample_idx = np.zeros(n_r)
            for k in range(n_r):
                sample_idx[k] = np.floor(UU2[k][i, j] / dt) + 1

            # Delay (in samples) that lines each trace up with the latest one.
            shift = np.max(sample_idx) - sample_idx
            padded_len = int(np.max(shift + w_L2))

            # Zero-padded buffer holding the delayed traces.
            aligned = np.zeros((n_r, padded_len))
            for r in range(n_r):
                first = int(shift[r])
                last = int(shift[r] + w_L2)
                aligned[r, first:last] = ttu[r, :]

            # Row 0 accumulates the running product, starting from the second row.
            for r in range(1, n_r):
                aligned[0, :] = aligned[0, :] * aligned[r, :]

            image[i, j] = np.sum(aligned[0, :])

    return image


    

@njit(parallel=True, fastmath=True, cache=True)
def cal_sta(w, n_r, w_L2, machine_epsilon):
    """Return a copy of ``w`` whose first ``n_r`` rows hold a short/long window ratio.

    For sample ``j`` of each processed row, the mean absolute amplitude inside a
    short centred window (half-width ``floor(w_L2 / 150)``) is divided by the
    mean absolute amplitude inside a long centred window (half-width
    ``floor(w_L2 / 5)``). Both windows are clipped to ``[0, w_L2)``.

    Parameters
    ----------
    w : 2-D array
        Input traces, one per row. It is not modified.
    n_r : int
        Number of leading rows to process; remaining rows keep the copied input.
    w_L2 : int
        Number of samples processed per row.
    machine_epsilon : float
        Added to the long-window mean before the division.

    Returns
    -------
    Array with the same shape and dtype as ``w``.
    """
    long_half = int(np.floor(w_L2 / 5))
    short_half = int(np.floor(w_L2 / 150))
    ratio = w.copy()
    for i in range(n_r):
        trace = w[i, :]
        for j in prange(w_L2):
            # Window bounds, clipped at both ends of the trace.
            short_lo = max(0, j - short_half)
            short_hi = min(j + short_half + 1, w_L2)
            long_lo = max(0, j - long_half)
            long_hi = min(j + long_half + 1, w_L2)

            short_avg = np.mean(np.abs(trace[short_lo:short_hi]))
            long_avg = np.mean(np.abs(trace[long_lo:long_hi])) + machine_epsilon
            ratio[i, j] = short_avg / (long_avg)
    return ratio
    




def cal_env(w, n_r):
    """Return a copy of ``w`` whose first ``n_r`` rows are replaced by their envelope.

    The envelope of a row ``x`` is ``sqrt(|x|**2 + |H(x)|**2)``, where ``H(x)`` is
    the imaginary part of ``scipy.signal.hilbert(x)``.

    Parameters
    ----------
    w : 2-D array_like of real numbers
        Input traces, one per row. It is not modified.
    n_r : int
        Number of leading rows to transform; remaining rows keep the input
        values. ``n_r <= 0`` returns a plain copy.

    Returns
    -------
    numpy.ndarray
        Same shape as ``w``. Floating-point input keeps its dtype; integer or
        boolean input is promoted to float so the envelope is not truncated.
    """
    w = np.asarray(w)
    # Integer/boolean storage would truncate the envelope on assignment.
    if w.dtype.kind in 'biu':
        w_env = w.astype(float)
    else:
        w_env = w.copy()
    if n_r <= 0:
        return w_env

    traces = w[:n_r]
    # One call transforms every selected row along the sample axis; this is a
    # plain NumPy/SciPy function, so numba's prange is not needed here.
    quadrature = np.imag(hilbert(traces, axis=-1))
    w_env[:n_r] = np.sqrt(np.abs(traces) ** 2 + np.abs(quadrature) ** 2)
    return w_env




@njit(parallel=True, fastmath=True, cache=True)
def process_data(TTu2):
    """Collapse a 2-D coefficient array into one weighted trace along axis 1.

    Steps, all performed on ``abs(TTu2)``:

    1. ``p[i]``  -- max minus min of row ``i``.
    2. ``h[j]``  -- squared column sum.
    3. ``tu[i, j] = p[i] * abs(TTu2)[i, j] * h[j]``.
    4. The result is the column sum of ``tu ** 2``.

    Parameters
    ----------
    TTu2 : 2-D array
        Input coefficients; the caller in this notebook passes the complex
        output of ``st_numba``.

    Returns
    -------
    1-D array with one value per column of ``TTu2``.
    """
    coefs = np.ascontiguousarray(TTu2)
    n_rows, n_cols = coefs.shape
    magnitude = np.abs(coefs)

    # Squared column sums (independent of the per-row weights below).
    h = np.sum(magnitude, axis=0) ** 2

    # Peak-to-peak amplitude of every row.
    p = np.zeros(n_rows)
    for r in prange(n_rows):
        line = magnitude[r, :]
        p[r] = np.max(line) - np.min(line)

    # Weight every sample by its row weight and its column weight.
    tu = np.zeros_like(magnitude)
    for r in prange(n_rows):
        for c in range(n_cols):
            tu[r, c] = p[r] * magnitude[r, c] * h[c]

    ttu = np.sum(tu**2, axis=0)
    return ttu

@njit(parallel=True, fastmath=True, cache=True)
def st_numba(t, s, freqlow, freqhigh, alpha):
    """Gaussian-windowed Fourier coefficients of ``s`` on a frequency/time grid.

    For every frequency ``f`` and every time sample ``n`` the function sums
    ``s[k] * g(k) * exp(-2j*pi*f*t[k]) * dt`` over the samples ``k`` within
    three standard deviations of ``n``, where ``g`` is a normalised Gaussian
    centred on sample ``n`` with standard deviation ``1 / f``. This looks like
    an S-transform-style decomposition -- confirm against the reference used.

    Parameters
    ----------
    t : 1-D array
        Time axis; ``t[1] - t[0]`` is used as the sampling interval.
    s : 1-D array
        Signal, indexed with the same sample indices as ``t``.
    freqlow, freqhigh : float
        Frequency range. ``int((freqhigh - freqlow) / alpha) + 1`` frequencies
        are placed evenly between both ends with ``np.linspace``, so their
        actual spacing can differ from ``alpha``.
    alpha : float
        Nominal frequency step used only to derive the number of frequencies.

    Returns
    -------
    complex64 array of shape (number of frequencies, len(t)).
    """
    n_samples = len(t)
    dt = t[1] - t[0]
    n_freq = int((freqhigh - freqlow) / alpha) + 1
    freqs = np.linspace(freqlow, freqhigh, n_freq)

    # Pre-allocate the result and the complex carrier of every frequency.
    wcoefs = np.zeros((n_freq, n_samples), dtype=np.complex64)
    carriers = np.zeros((n_freq, n_samples), dtype=np.complex64)
    for m in range(n_freq):
        f = freqs[m]
        carriers[m, :] = np.exp(-1.0j * 2 * np.pi * f * t)  # vectorised over time

    # Normalisation factor of the Gaussian window for every frequency.
    sigma_f = 1.0 / freqs
    sqrt_2pi = np.sqrt(2 * np.pi)
    norm_factors = 1.0 / (sqrt_2pi * sigma_f)

    # One parallel task per frequency.
    for m in prange(n_freq):
        sigma = sigma_f[m]
        variance = sigma * sigma
        carrier = carriers[m, :]
        scale = norm_factors[m]
        # The window is truncated at three standard deviations; the half-width
        # in samples depends only on the frequency.
        half_width = int(min(3 * sigma / dt, n_samples))

        for n in range(n_samples):
            center_time = n * dt
            lo = max(0, n - half_width)
            hi = min(n_samples, n + half_width + 1)

            acc = 0.0 + 0.0j  # complex accumulator
            for k in range(lo, hi):
                offset = k * dt - center_time
                weight = scale * np.exp(-0.5 * (offset * offset) / variance)
                acc += s[k] * (weight * carrier[k])

            wcoefs[m, n] = acc * dt

    return wcoefs










In [None]:
# Load every input array from the archive. np.load on an .npz returns a lazy
# NpzFile that keeps the file open, so the arrays are read inside a `with`
# block and the archive is closed as soon as they are in memory.
with np.load('Ori_wave.npz') as loaded_data:
    W_wave = loaded_data['W_wave']  # indexed below as [trace, sample, ci, bb]
    t = loaded_data['t']
    ssx = loaded_data['ssx']
    ssy = loaded_data['ssy']
    V = loaded_data['V']
    st11 = loaded_data['st11']
    st22 = loaded_data['st22']
    x = loaded_data['x']
    y = loaded_data['y']

n = V.shape[0]    # first dimension of V; used as both image dimensions below
n_r = len(st11)   # one trace per (st11[i], st22[i]) index pair

In [None]:
# Grid and solver settings.
dt = t[2] - t[1]
h = 2
K = 20
X, Y = np.meshgrid(x, y)
L1, L2 = X.shape
node2, node1 = V.shape

# One table per (st11[i], st22[i]) node: the field starts at +inf everywhere
# except a zero at that node and is handed to the Cython solver together with
# the reciprocal of V.
# NOTE(review): the meaning of the 0.01 and K arguments is defined by
# f12.f1_cython, which is not visible here -- confirm before changing them.
UU2 = []
for i in range(n_r):
    u = np.full(V.shape, np.inf)
    u[st11[i], st22[i]] = 0
    UU2_i = f12.f1_cython(u, h, 1 / V, 0.01, node1, node2, K)
    UU2.append(UU2_i)
UU2 = np.array(UU2)


In [None]:
# The numba kernels have to be compiled before they are timed; otherwise the
# first measurement also contains the JIT compilation. The untimed warm-up pass
# below runs every kernel once on a single W_wave slice, so the cell gives
# comparable numbers on a fresh kernel without having to be run twice.
Time = np.zeros(3)
BEI = np.linspace(0, 10, 11)

# All buffers are allocated before any timer starts so that no timed section
# pays for allocation.
W1 = np.zeros((n, n, len(ssx), len(BEI)))
W_ttu = np.zeros((n_r, len(t), len(ssx), len(BEI)))
W2 = np.zeros((n, n, len(ssx), len(BEI)))
W_sta = np.zeros((n_r, len(t), len(ssx), len(BEI)))
W3 = np.zeros((n, n, len(ssx), len(BEI)))
W_env = np.zeros((n_r, len(t), len(ssx), len(BEI)))
ttu = np.zeros(W_wave[:, :, 0, 0].shape)
w_L2 = ttu.shape[1]  # samples per trace; shared by all three methods

#---------------------------------------------------
# Untimed warm-up: triggers compilation of st_numba, process_data, cal_sta and
# x_corr_n_2d with the same argument types used in the timed loops.
w = W_wave[:, :, 0, 0]
for i in range(n_r):
    wcoefs = st_numba(t, w[i, :], 1, 500, 50)
    ttu[i, :] = process_data(wcoefs)
WW1 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, ttu)
w_sta = cal_sta(w, n_r, w_L2, machine_epsilon)
WW2 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, w_sta)
w_env = cal_env(w, n_r)
WW3 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, w_env)

#---------------------------------------------------
# Method 1: st_numba + process_data features.
start_tim = time.perf_counter()
for bb in range(len(BEI)):
    for ci in range(W_wave.shape[2]):
        w = W_wave[:, :, ci, bb]
        for i in range(n_r):
            wcoefs = st_numba(t, w[i, :], 1, 500, 50)
            ttu[i, :] = process_data(wcoefs)
        WW1 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, ttu)
end_tim = time.perf_counter()
Time[0] = end_tim - start_tim

#---------------------------------------------------
# Method 2: cal_sta features.
start_tim = time.perf_counter()
for bb in range(len(BEI)):
    for ci in range(W_wave.shape[2]):
        w = W_wave[:, :, ci, bb]
        w_sta = cal_sta(w, n_r, w_L2, machine_epsilon)
        WW2 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, w_sta)
end_tim = time.perf_counter()
Time[1] = end_tim - start_tim

#---------------------------------------------------
# Method 3: cal_env features.
start_tim = time.perf_counter()
for bb in range(len(BEI)):
    for ci in range(W_wave.shape[2]):
        w = W_wave[:, :, ci, bb]
        w_env = cal_env(w, n_r)
        WW3 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, w_env)
end_tim = time.perf_counter()
Time[2] = end_tim - start_tim

  

In [None]:
# Full run: for every BEI index (bb) and every slice index (ci) of W_wave,
# compute the three feature sets, image each of them with x_corr_n_2d, and
# store the images, the features and the position of each image maximum.
BEI = np.linspace(0, 10, 11)

XXU = np.zeros((3, 2, len(ssx), len(BEI)))  # (method, row/col of maximum, ci, bb)
W1 = np.zeros((n, n, len(ssx), len(BEI)))
W2 = np.zeros((n, n, len(ssx), len(BEI)))
W3 = np.zeros((n, n, len(ssx), len(BEI)))

W_ttu = np.zeros((n_r, len(t), len(ssx), len(BEI)))
W_sta = np.zeros((n_r, len(t), len(ssx), len(BEI)))
W_env = np.zeros((n_r, len(t), len(ssx), len(BEI)))

for bb in range(len(BEI)):
    for ci in range(W_wave.shape[2]):
        # View into W_wave; none of the calls below modify it, so nothing has
        # to be written back afterwards.
        w = W_wave[:, :, ci, bb]

        # Method 1: st_numba + process_data features.
        ttu = w.copy()
        for i in range(n_r):
            wcoefs = st_numba(t, w[i, :], 1, 500, 50)
            ttu[i, :] = process_data(wcoefs)
        w_L2 = ttu.shape[1]
        WW1 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, ttu)

        # Method 2: cal_sta features.
        w_sta = cal_sta(w, n_r, w_L2, machine_epsilon)
        WW2 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, w_sta)

        # Method 3: cal_env features.
        w_env = cal_env(w, n_r)
        WW3 = x_corr_n_2d(L1, L2, n_r, UU2, dt, w_L2, w_env)

        # Slice assignment already copies the data into the result buffers.
        W1[:, :, ci, bb] = WW1
        W2[:, :, ci, bb] = WW2
        W3[:, :, ci, bb] = WW3

        # (row, column) index of the maximum of each image.
        xu = []
        plot_data = [WW1, WW2, WW3]
        for WW in plot_data:
            ii, jj = np.unravel_index(np.argmax(WW), WW.shape)
            xu.append((ii, jj))
        XXU[:, :, ci, bb] = np.array(xu)

        W_ttu[:, :, ci, bb] = ttu
        W_sta[:, :, ci, bb] = w_sta
        W_env[:, :, ci, bb] = w_env

In [None]:
# np.savez('RESULT_2d_new.npz', W1=W1, W2=W2, W3=W3, XXU=XXU, W_wave=W_wave, W_ttu=W_ttu, W_sta=W_sta, W_env=W_env, st11=st11, st22=st22, ssy=ssy, ssx=ssx, V=V,x=x,y=y,t=t,Time=Time)