### Common code (run this first)

In [None]:
import numpy as np
import xarray as xr
from scipy.fft import fft2, ifft2
from scipy.linalg import solve_banded

from tools import loader

def get_time_base_data(
    flag,
    time_s,
    need_data=('u', 'v', 'w', 'theta', 'p', 'advec_u', 'advec_v', 'advec_w', 'bouy_w'),
):
    """Load the 3-D fields and the density profile nearest to a requested time.

    Args:
        flag: Case identifier forwarded to ``loader.load_base_data``.
        time_s: Requested time in seconds. The nearest time present in the
            3-D dataset's ``time`` coordinate is used.
        need_data: Names of variables to extract from the grid-1 dataset.
            A tuple by default so the shared default can never be mutated;
            any iterable of names works.

    Returns:
        dict mapping each name in ``need_data`` to ``dataset[name]`` indexed at
        the selected time along its leading axis, plus ``'rho'``: the ``rho``
        profile from the ``type='pr'`` dataset as a numpy array reshaped to
        ``(n, 1, 1)`` so it broadcasts over the last two axes of a 3-D field.

    Raises:
        ValueError: if the nearest time cannot be matched back to an index.
    """
    dataset = loader.load_base_data(flag=flag, grid_num=1)  # load the 3-D output
    time_coords = dataset.coords['time']
    # Snap the requested time to the nearest available output time
    # (only the time axis is searched here).
    time_nearest = time_coords.sel(
        {"time": np.timedelta64(time_s, 's')}, method='nearest'
    ).values
    matches = np.flatnonzero(time_coords.to_numpy() == time_nearest)
    if matches.size == 0:
        # Previously this fell through as None, which silently inserted a new
        # axis in the slicing below and then failed on `None + 1`.
        raise ValueError(f"time {time_s}s could not be located for flag {flag!r}")
    time_index = int(matches[0])
    result_data = {name: dataset[name][time_index, :, :, :] for name in need_data}

    dataset_pr = loader.load_base_data(flag=flag, type='pr')
    # NOTE(review): the profile dataset is indexed at time_index + 1, presumably
    # because it stores one extra leading record relative to the 3-D output — confirm.
    rho = dataset_pr['rho'][time_index + 1]
    result_data['rho'] = rho.to_numpy().reshape(rho.shape[0], 1, 1)

    return result_data

def get_press_source(input_data, dx=20, dy=20, dz=10):
    """Build the inertial and buoyancy source terms for the pressure solve.

    Args:
        input_data: Mapping with 3-D fields ``'advec_u'``, ``'advec_v'``,
            ``'advec_w'``, ``'bouy_w'`` indexed as (z, y, x), and ``'rho'``,
            which must broadcast against them (e.g. shape ``(nz, 1, 1)``).
            Fields may be xarray DataArrays or plain numpy arrays.
        dx, dy, dz: Grid spacings used by the finite differences. The defaults
            reproduce the previously hard-coded values.

    Returns:
        dict with numpy arrays:
            ``'inertial'``: rho * (d/dx advec_u + d/dy advec_v + d/dz advec_w)
            ``'buoyancy'``: rho * d/dz bouy_w
        In both, level 0 is a copy of level 1 and the top level is zero.

    NaNs in the inputs are treated as 0.
    """
    rho = input_data['rho']

    def _nan_to_zero(data):
        # Same effect as DataArray.fillna(0), but also accepts numpy arrays.
        # np.array copies, so the caller's data is never modified; dtype is kept.
        arr = np.array(data)
        arr[np.isnan(arr)] = 0
        return arr

    def partial_z(data):
        # Backward difference along axis 0: (f[k] - f[k-1]) / dz.
        # Level 0 wraps around to the top level; it is overwritten below.
        arr = _nan_to_zero(data)
        return (arr - np.roll(arr, shift=1, axis=0)) / dz

    def partial_x(data):
        # Forward difference along axis 2, periodic via np.roll.
        arr = _nan_to_zero(data)
        return (np.roll(arr, shift=-1, axis=2) - arr) / dx

    def partial_y(data):
        # Forward difference along axis 1, periodic via np.roll.
        arr = _nan_to_zero(data)
        return (np.roll(arr, shift=-1, axis=1) - arr) / dy

    def _apply_vertical_bc(field):
        # Copy level 1 into the (wrapped-around) level 0 for the Neumann
        # bottom, and zero the top level for the Dirichlet top.
        field = np.array(field)
        field[0, :, :] = field[1, :, :]
        field[-1, :, :] = 0
        return field

    inertial = (
        partial_x(input_data['advec_u'])
        + partial_y(input_data['advec_v'])
        + partial_z(input_data['advec_w'])
    )
    buoyancy = partial_z(input_data['bouy_w'])

    return {
        'inertial': _apply_vertical_bc(inertial) * rho,
        'buoyancy': _apply_vertical_bc(buoyancy) * rho,
    }

def solve_p(f, Lx=5000, Ly=5000, Lz=1610):
    """Solve the Poisson equation ``lap(p) = f`` on a horizontally periodic grid.

    Horizontal directions are handled spectrally (2-D FFT over the last two
    axes). For every horizontal wavenumber the remaining vertical problem is a
    tridiagonal system. It is solved with the Thomas algorithm, vectorised over
    all wavenumbers at once instead of one solver call per wavenumber.

    Boundary conditions:
        * bottom (index 0): Neumann, dp/dz = 0, via a ghost level p[-1] = p[0];
        * top (index nz-1): Dirichlet, p = 0.

    Args:
        f: Source term of shape (nz, ny, nx) (anything ``np.array`` accepts).
        Lx, Ly, Lz: Domain extents; spacings are L / n along each axis.
            The defaults reproduce the previously hard-coded values.

    Returns:
        Real numpy array of shape (nz, ny, nx); the top level is zero.
    """
    f = np.array(f)
    nz, ny, nx = f.shape
    dx, dy, dz = Lx / nx, Ly / ny, Lz / nz

    # Eigenvalue of the horizontal Laplacian for each (ky[j], kx[i]) mode,
    # laid out (ny, nx) to match the FFT output.
    # NOTE(review): these are continuous wavenumbers, while get_press_source
    # builds the source with 2-point finite differences; the matching discrete
    # eigenvalue would be -(2 - 2*cos(k*d)) / d**2 per direction — confirm.
    kx = 2 * np.pi * np.fft.fftfreq(nx, d=dx)
    ky = 2 * np.pi * np.fft.fftfreq(ny, d=dy)
    lam = -(ky[:, np.newaxis] ** 2 + kx[np.newaxis, :] ** 2)

    f_hat = np.fft.fft2(f, axes=(2, 1))  # (nz, ny, nx), complex

    # The top level is pinned to zero (Dirichlet), so only levels 0..nz-2 are
    # unknown. Row nz-2 needs no coupling to the top level because p[nz-1] = 0.
    p_hat = np.zeros((nz, ny, nx), dtype=np.complex128)
    m = nz - 1
    if m > 0:
        off = 1.0 / dz**2                 # sub- and super-diagonal coefficient
        diag_interior = -2.0 * off + lam  # main diagonal, levels 1..m-1
        diag_bottom = -off + lam          # ghost p[-1] = p[0] folds one `off` in

        # Forward sweep. The system is (weakly) diagonally dominant with lam <= 0,
        # so every denominator is <= -off and never zero.
        c_prime = np.empty((m, ny, nx))
        d_prime = np.empty((m, ny, nx), dtype=np.complex128)
        c_prime[0] = off / diag_bottom
        d_prime[0] = f_hat[0] / diag_bottom
        for k in range(1, m):
            denom = diag_interior - off * c_prime[k - 1]
            c_prime[k] = off / denom
            d_prime[k] = (f_hat[k] - off * d_prime[k - 1]) / denom

        # Back substitution.
        p_hat[m - 1] = d_prime[m - 1]
        for k in range(m - 2, -1, -1):
            p_hat[k] = d_prime[k] - c_prime[k] * p_hat[k + 1]

    # For real f the operator depends only on |k|^2, so the inverse transform
    # is real up to round-off; the imaginary residue is dropped.
    return np.fft.ifft2(p_hat, axes=(2, 1)).real


### Compute pressure source terms and save

In [None]:
flag = 'U0'
# Requested times in seconds (3600..5400 every 10 s). get_time_base_data snaps
# each one to the nearest output time; these requested values are what get
# saved as the "time" coordinate.
time_coords = np.array(range(3600, 5410, 10))

w_datas = []
p_datas = []
theta_datas = []
inertial_datas = []
buoyancy_datas = []
for time_s in time_coords:
    print(f"process {time_s}...")
    base_data = get_time_base_data(flag, time_s)
    # NOTE(review): this averages w with its neighbour along axis 1, which for a
    # (z, y, x) field is the y axis. If the intent is to move w onto the scalar
    # z levels (the result is saved against zu_3d below), the average should be
    # along axis 0 — confirm both the axis and the shift direction.
    w_datas.append((base_data['w'] + np.roll(base_data['w'], shift=-1, axis=1)) / 2)
    p_datas.append(base_data['p'])
    theta_datas.append(base_data['theta'])

    press_source = get_press_source(base_data)
    inertial_datas.append(press_source['inertial'])
    buoyancy_datas.append(press_source['buoyancy'])

# Reload the 3-D dataset only to copy its spatial coordinates.
dataset = loader.load_base_data(flag=flag, grid_num=1)
dims = ["time", "z", "y", "x"]
ds = xr.Dataset(
    {
        "w": (dims, np.array(w_datas)),
        "p": (dims, np.array(p_datas)),
        "theta": (dims, np.array(theta_datas)),
        "inertial": (dims, np.array(inertial_datas)),
        "buoyancy": (dims, np.array(buoyancy_datas)),
    },
    coords={
        "time": time_coords,
        "z": dataset.coords['zu_3d'].values,
        "x": dataset.coords['x'].values,
        "y": dataset.coords['y'].values,
    },
)

# Save as NetCDF.
save_name = f"./data/{flag}_press_source.nc"
ds.to_netcdf(save_name)
print(f"Already save {save_name}")

### Solve for pressure and save

In [None]:
flag = 'U0'

# Load the source terms written by the previous cell. The context manager
# releases the file handle afterwards, so that cell can be re-run and
# overwrite the file.
with xr.open_dataset(f"./data/{flag}_press_source.nc") as source_data:
    p_inertial_list = []
    p_buoyancy_list = []
    for time_i in range(source_data.sizes['time']):
        print(f'caling {time_i}')
        # Drop the lowest level (index 0) before solving; the "z" coordinate
        # below is sliced the same way.
        inertial = source_data['inertial'][time_i][1:]
        buoyancy = source_data['buoyancy'][time_i][1:]

        p_inertial_list.append(solve_p(inertial))
        p_buoyancy_list.append(solve_p(buoyancy))

    # Build the output while the source file is still open; everything is in
    # memory afterwards.
    dims = ["time", "z", "y", "x"]
    ds = xr.Dataset(
        {
            "p_inertial": (dims, np.array(p_inertial_list)),
            "p_buoyancy": (dims, np.array(p_buoyancy_list)),
        },
        coords={
            "time": source_data.coords['time'],
            "z": source_data.coords['z'].values[1:],
            "x": source_data.coords['x'].values,
            "y": source_data.coords['y'].values,
        },
    )

# Save as NetCDF.
save_name = f"./data/{flag}_press_solve.nc"
ds.to_netcdf(save_name)
print(f"Already save {save_name}")

caling 0
caling 1
caling 2
caling 3
caling 4
caling 5
caling 6
caling 7
caling 8
caling 9
caling 10
caling 11
caling 12
caling 13
caling 14
caling 15
caling 16
caling 17
caling 18
caling 19
caling 20
caling 21
caling 22
caling 23
caling 24
caling 25
caling 26
caling 27
caling 28
caling 29
caling 30
caling 31
caling 32
caling 33
caling 34
caling 35
caling 36
caling 37
caling 38
caling 39
caling 40
caling 41
caling 42
caling 43
caling 44
caling 45
caling 46
caling 47
caling 48
caling 49
caling 50
caling 51
caling 52
caling 53
caling 54
caling 55
caling 56
caling 57
caling 58
caling 59
caling 60
caling 61
caling 62
caling 63
caling 64
caling 65
caling 66
caling 67
caling 68
caling 69
caling 70
caling 71
caling 72
caling 73
caling 74
caling 75
caling 76
caling 77
caling 78
caling 79
caling 80
caling 81
caling 82
caling 83
caling 84
caling 85
caling 86
caling 87
caling 88
caling 89
caling 90
caling 91
caling 92
caling 93
caling 94
caling 95
caling 96
caling 97
caling 98
caling 99
caling 100