In [1]:
import os
import sys

# 设置seisbench缓存路径 - 必须在导入seisbench之前设置
CACHE_ROOT = '/mnt/data/tianyu/seisbench_cache'
os.environ['SEISBENCH_CACHE_ROOT'] = CACHE_ROOT
import numpy as np
import h5py
import pandas as pd
from pathlib import Path
from PyEMD import EMD
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
import seisbench.data as sbd
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import time


  import pkg_resources


In [2]:
# load dataset
dataset = sbd.STEAD()
print(f"Dataset loaded with {len(dataset)} samples.")



Dataset loaded with 1265657 samples.


In [3]:
# get one of EMD result
emd = EMD(max_imfs=3, max_ierations=500)
sample = dataset.get_waveforms(0)
sample_IMFs = []
for i in range(3):
    IMFs = emd(sample[i])
    print(f"Component {i} has {IMFs.shape[0]} IMFs.")
    first_3 = IMFs[:3,:]  # only keep first 3 IMFs
    # Padding zero if less than 3 IMFs
    if IMFs.shape[0] < 3:
        padding = np.zeros((3 - IMFs.shape[0], IMFs.shape[1]))
        first_3 = np.vstack((IMFs, padding)) 
    sample_IMFs.append(first_3)    
sample_IMFs = np.array(sample_IMFs)  # shape (3, 3, 6000) 
print(f"Sample IMFs shape: {sample_IMFs.shape}")
print(sample_IMFs)


Component 0 has 12 IMFs.
Component 1 has 12 IMFs.
Component 2 has 12 IMFs.
Sample IMFs shape: (3, 3, 6000)
[[[ 1.19052005e+00  1.32109511e+00  1.29347217e+00 ...  7.29397312e-03
    1.91898586e-03 -1.44385053e-02]
  [-5.97637296e-01 -6.46309972e-01 -6.46760583e-01 ...  2.10213828e+00
    2.17930651e+00  2.20022917e+00]
  [-1.29614699e+00 -1.31218767e+00 -1.20875382e+00 ...  5.67584112e-03
   -1.17643634e-02 -7.61724077e-03]]

 [[ 1.19052005e+00  1.32109511e+00  1.29347217e+00 ...  7.29397312e-03
    1.91898586e-03 -1.44385053e-02]
  [-5.97637296e-01 -6.46309972e-01 -6.46760583e-01 ...  2.10213828e+00
    2.17930651e+00  2.20022917e+00]
  [-1.29614699e+00 -1.31218767e+00 -1.20875382e+00 ...  5.67584112e-03
   -1.17643634e-02 -7.61724077e-03]]

 [[ 1.19052005e+00  1.32109511e+00  1.29347217e+00 ...  7.29397312e-03
    1.91898586e-03 -1.44385053e-02]
  [-5.97637296e-01 -6.46309972e-01 -6.46760583e-01 ...  2.10213828e+00
    2.17930651e+00  2.20022917e+00]
  [-1.29614699e+00 -1.31218767e+0

In [4]:
sample_IMFs

array([[[ 1.19052005e+00,  1.32109511e+00,  1.29347217e+00, ...,
          7.29397312e-03,  1.91898586e-03, -1.44385053e-02],
        [-5.97637296e-01, -6.46309972e-01, -6.46760583e-01, ...,
          2.10213828e+00,  2.17930651e+00,  2.20022917e+00],
        [-1.29614699e+00, -1.31218767e+00, -1.20875382e+00, ...,
          5.67584112e-03, -1.17643634e-02, -7.61724077e-03]],

       [[ 1.19052005e+00,  1.32109511e+00,  1.29347217e+00, ...,
          7.29397312e-03,  1.91898586e-03, -1.44385053e-02],
        [-5.97637296e-01, -6.46309972e-01, -6.46760583e-01, ...,
          2.10213828e+00,  2.17930651e+00,  2.20022917e+00],
        [-1.29614699e+00, -1.31218767e+00, -1.20875382e+00, ...,
          5.67584112e-03, -1.17643634e-02, -7.61724077e-03]],

       [[ 1.19052005e+00,  1.32109511e+00,  1.29347217e+00, ...,
          7.29397312e-03,  1.91898586e-03, -1.44385053e-02],
        [-5.97637296e-01, -6.46309972e-01, -6.46760583e-01, ...,
          2.10213828e+00,  2.17930651e+00,  2.200

In [None]:
# now process all dataset and save to hdf5
from multiprocessing import Pool
import numpy as np
from PyEMD import EMD
def process_channel(args):
    i, channel_data = args
    emd_instance = EMD(max_imfs=3, max_iterations=500)
    IMFs = emd_instance(channel_data)
    # times['emd'].append(time.time() - start_emd)
    
    first_3 = IMFs[:3,:]
    if IMFs.shape[0] < 3:
        padding = np.zeros((3 - IMFs.shape[0], IMFs.shape[1]))
        first_3 = np.vstack((IMFs, padding)) 
    return first_3

def process_sample(index):

    sample = dataset.get_waveforms(index)
    # times['data_prep'].append(time.time() - start_data)
    
    # 准备参数
    channel_args = [(i, sample[i]) for i in range(3)]
    
    # 使用进程池并行处理
    with Pool(processes=3) as pool:
        sample_IMFs = pool.map(process_channel, channel_args)
    
    sample_IMFs = np.array(sample_IMFs)  # shape (3, 3, 6000) 
    return sample_IMFs

num_samples = len(dataset)
hdf5_path = Path('./STEAD_emd.hdf5')

batch_size = 128
with h5py.File(hdf5_path, 'w') as hdf5_file:
    dset = hdf5_file.create_dataset('IMFs', shape=(num_samples, 3, 3, 6000), dtype=np.float32)
    
    for batch_start in tqdm(range(0, num_samples, batch_size)):
        batch_end = min(batch_start + batch_size, num_samples)
        batch_indices = range(batch_start, batch_end)
        
        with ThreadPoolExecutor(max_workers=64) as executor:
            results = list(executor.map(process_sample, batch_indices))
        
        # Write batch to HDF5
        for i, result in enumerate(results):
            dset[batch_start + i] = result

  0%|          | 4/9888 [01:33<64:17:54, 23.42s/it]Process ForkPoolWorker-1828:
Process ForkPoolWorker-1818:
Process ForkPoolWorker-1846:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/pool.p

In [None]:
import time
from collections import defaultdict
from multiprocessing import Pool
import numpy as np
from PyEMD import EMD
from concurrent.futures import ThreadPoolExecutor
import h5py
from pathlib import Path
from tqdm import tqdm
# 全局计时字典
times = defaultdict(list)
def process_channel(args):
    i, channel_data = args
    start_emd = time.time()
    
    emd_instance = EMD(max_imfs=3, max_iterations=500)
    IMFs = emd_instance(channel_data)
    
    times['emd_channel'].append(time.time() - start_emd)
    
    first_3 = IMFs[:3,:]
    if IMFs.shape[0] < 3:
        padding = np.zeros((3 - IMFs.shape[0], IMFs.shape[1]))
        first_3 = np.vstack((IMFs, padding)) 
    return first_3
def process_sample(index):
    # 数据获取计时
    start_data = time.time()
    sample = dataset.get_waveforms(index)
    data_time = time.time() - start_data
    times['get_waveforms'].append(data_time)
    
    # 准备参数
    channel_args = [(i, sample[i]) for i in range(3)]
    
    # 使用进程池并行处理
    start_parallel = time.time()
    with Pool(processes=3) as pool:
        sample_IMFs = pool.map(process_channel, channel_args)
    times['parallel_processing'].append(time.time() - start_parallel)
    
    sample_IMFs = np.array(sample_IMFs)  # shape (3, 3, 6000) 
    return sample_IMFs
def print_performance_stats():
    """打印性能统计信息"""
    print("\n=== 性能分析报告 ===")
    total_samples = len(times['get_waveforms'])
    
    for key in times:
        if times[key]:
            avg_time = np.mean(times[key])
            total_time = np.sum(times[key])
            max_time = np.max(times[key])
            min_time = np.min(times[key])
            print(f"{key}:")
            print(f"  平均时间: {avg_time:.4f}s")
            print(f"  总时间: {total_time:.2f}s")
            print(f"  最大时间: {max_time:.4f}s")
            print(f"  最小时间: {min_time:.4f}s")
            print(f"  样本数: {len(times[key])}")
    
    # 计算各阶段占比
    total_processing_time = sum(np.sum(times[key]) for key in times if key != 'get_waveforms')
    data_loading_time = np.sum(times['get_waveforms'])
    total_time = data_loading_time + total_processing_time
    
    if total_time > 0:
        print(f"\n时间分布:")
        print(f"  数据加载: {data_loading_time:.2f}s ({data_loading_time/total_time*100:.1f}%)")
        print(f"  EMD处理: {np.sum(times['emd_channel']):.2f}s ({np.sum(times['emd_channel'])/total_time*100:.1f}%)")
        print(f"  并行开销: {np.sum(times['parallel_processing']):.2f}s ({np.sum(times['parallel_processing'])/total_time*100:.1f}%)")
# 主处理循环
num_samples = len(dataset)
hdf5_path = Path('./STEAD_emd.hdf5')
batch_size = 32
try:
    with h5py.File(hdf5_path, 'w') as hdf5_file:
        dset = hdf5_file.create_dataset('IMFs', shape=(num_samples, 3, 3, 6000), dtype=np.float32)
        
        # 添加批处理计时
        batch_times = []
        
        for batch_start in tqdm(range(0, num_samples, batch_size), desc="Processing batches"):
            batch_start_time = time.time()
            
            batch_end = min(batch_start + batch_size, num_samples)
            batch_indices = range(batch_start, batch_end)
            
            with ThreadPoolExecutor(max_workers=16) as executor:
                results = list(executor.map(process_sample, batch_indices))
            
            # Write batch to HDF5
            hdf5_write_time = time.time()
            for i, result in enumerate(results):
                dset[batch_start + i] = result
            times['hdf5_write'].append(time.time() - hdf5_write_time)
            
            batch_time = time.time() - batch_start_time
            batch_times.append(batch_time)
            times['batch_total'].append(batch_time)
            
            # 每处理几个批次就打印一次中间统计
            if len(batch_times) % 5 == 0:
                avg_batch_time = np.mean(batch_times[-5:])
                print(f"\n最近5批平均处理时间: {avg_batch_time:.2f}s")
                if 'get_waveforms' in times and times['get_waveforms']:
                    avg_data_time = np.mean(times['get_waveforms'][-batch_size*5:])
                    print(f"最近数据加载平均时间: {avg_data_time:.4f}s")
finally:
    # 最终性能报告
    print_performance_stats()
    
    # 保存性能数据到文件
    performance_data = {k: np.array(v) for k, v in times.items()}
    np.savez('performance_stats.npz', **performance_data)
    print("性能数据已保存到 performance_stats.npz")

Processing batches:   0%|          | 0/39552 [00:00<?, ?it/s]

Processing batches:   0%|          | 5/39552 [00:35<85:45:14,  7.81s/it]


最近5批平均处理时间: 7.16s
最近数据加载平均时间: 0.9869s


Processing batches:   0%|          | 10/39552 [01:08<74:28:09,  6.78s/it]


最近5批平均处理时间: 6.56s
最近数据加载平均时间: 1.0470s


Processing batches:   0%|          | 15/39552 [01:39<70:31:15,  6.42s/it]


最近5批平均处理时间: 6.24s
最近数据加载平均时间: 1.0227s


Processing batches:   0%|          | 16/39552 [01:45<68:46:32,  6.26s/it]Process ForkPoolWorker-2099:
Process ForkPoolWorker-2104:
Process ForkPoolWorker-2109:
Process ForkPoolWorker-2108:
Process ForkPoolWorker-2101:
Process ForkPoolWorker-2105:
Process ForkPoolWorker-2102:
Process ForkPoolWorker-2093:
Process ForkPoolWorker-2100:
Process ForkPoolWorker-2107:
Process ForkPoolWorker-2106:
Process ForkPoolWorker-2103:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/chenty/.local/share/uv/python/cpython-3.9.23-linux-x86_64-gnu/lib/python3.9/multiprocessing/

In [6]:
import time
from collections import defaultdict
import numpy as np
from PyEMD import EMD
from concurrent.futures import ThreadPoolExecutor
import h5py
from pathlib import Path
from tqdm import tqdm
# 全局计时字典
times = defaultdict(list)
# 预初始化EMD实例（避免重复创建的开销）
emd_instances = [EMD(max_imfs=3, max_iterations=500) for _ in range(3)]
def process_sample_serial(index):
    # 数据获取计时
    start_data = time.time()
    sample = dataset.get_waveforms(index)
    times['get_waveforms'].append(time.time() - start_data)
    
    sample_IMFs = []
    start_emd_total = time.time()
    
    for i in range(3):
        start_emd = time.time()
        IMFs = emd_instances[i](sample[i])  # 使用预初始化的实例
        times['emd_single'].append(time.time() - start_emd)
        
        first_3 = IMFs[:3,:]
        if IMFs.shape[0] < 3:
            padding = np.zeros((3 - IMFs.shape[0], IMFs.shape[1]))
            first_3 = np.vstack((IMFs, padding)) 
        sample_IMFs.append(first_3)
    
    times['emd_total'].append(time.time() - start_emd_total)
    
    return np.array(sample_IMFs)
# 主处理循环
num_samples = len(dataset)
hdf5_path = Path('./STEAD_emd.hdf5')
batch_size = 32  # 可以调大一些，因为串行处理
try:
    with h5py.File(hdf5_path, 'w') as hdf5_file:
        dset = hdf5_file.create_dataset('IMFs', shape=(num_samples, 3, 3, 6000), dtype=np.float32)
        
        batch_times = []
        
        for batch_start in tqdm(range(0, 160, batch_size), desc="Processing batches"):
            batch_start_time = time.time()
            
            batch_end = min(batch_start + batch_size, num_samples)
            batch_indices = range(batch_start, batch_end)
            
            # 使用串行处理
            results = []
            for index in batch_indices:
                results.append(process_sample_serial(index))
            
            # Write batch to HDF5
            hdf5_write_time = time.time()
            for i, result in enumerate(results):
                dset[batch_start + i] = result
            times['hdf5_write'].append(time.time() - hdf5_write_time)
            
            batch_time = time.time() - batch_start_time
            batch_times.append(batch_time)
            times['batch_total'].append(batch_time)
            
            # 实时性能监控
            if len(batch_times) % 2 == 0:  # 每2批显示一次
                avg_batch_time = np.mean(batch_times[-2:])
                print(f"\n最近2批平均处理时间: {avg_batch_time:.2f}s")
                if times['get_waveforms']:
                    recent_data_times = times['get_waveforms'][-batch_size*2:]
                    avg_data_time = np.mean(recent_data_times) if recent_data_times else 0
                    print(f"数据加载平均时间: {avg_data_time:.4f}s")
                if times['emd_single']:
                    recent_emd_times = times['emd_single'][-batch_size*2*3:]  # 2批 * 3通道
                    avg_emd_time = np.mean(recent_emd_times) if recent_emd_times else 0
                    print(f"单通道EMD平均时间: {avg_emd_time:.4f}s")
finally:
    # 最终性能报告
    print("\n=== 最终性能分析报告 ===")
    total_samples = len(times['get_waveforms'])
    
    for key in times:
        if times[key]:
            avg_time = np.mean(times[key])
            total_time = np.sum(times[key])
            print(f"{key}: 平均{avg_time:.4f}s, 总计{total_time:.2f}s, 样本数{len(times[key])}")
    
    # 计算各阶段占比
    data_loading_time = np.sum(times['get_waveforms'])
    emd_processing_time = np.sum(times['emd_total'])
    hdf5_time = np.sum(times['hdf5_write'])
    total_time = data_loading_time + emd_processing_time + hdf5_time
    
    if total_time > 0:
        print(f"\n时间分布:")
        print(f"  数据加载: {data_loading_time:.2f}s ({data_loading_time/total_time*100:.1f}%)")
        print(f"  EMD处理: {emd_processing_time:.2f}s ({emd_processing_time/total_time*100:.1f}%)")
        print(f"  HDF5写入: {hdf5_time:.2f}s ({hdf5_time/total_time*100:.1f}%)")
    
    # 保存性能数据
    performance_data = {k: np.array(v) for k, v in times.items()}
    np.savez('performance_stats_serial.npz', **performance_data)
    print("性能数据已保存到 performance_stats_serial.npz")

Processing batches:   0%|          | 0/5 [00:00<?, ?it/s]

Processing batches:  40%|████      | 2/5 [00:22<00:33, 11.14s/it]


最近2批平均处理时间: 11.14s
数据加载平均时间: 0.0023s
单通道EMD平均时间: 0.1151s


Processing batches:  80%|████████  | 4/5 [00:48<00:12, 12.19s/it]


最近2批平均处理时间: 13.33s
数据加载平均时间: 0.0022s
单通道EMD平均时间: 0.1380s


Processing batches: 100%|██████████| 5/5 [01:11<00:00, 14.20s/it]


=== 最终性能分析报告 ===
get_waveforms: 平均0.0023s, 总计0.36s, 样本数160
emd_single: 平均0.1470s, 总计70.58s, 样本数480
emd_total: 平均0.4411s, 总计70.58s, 样本数160
hdf5_write: 平均0.0113s, 总计0.06s, 样本数5
batch_total: 平均14.2017s, 总计71.01s, 样本数5

时间分布:
  数据加载: 0.36s (0.5%)
  EMD处理: 70.58s (99.4%)
  HDF5写入: 0.06s (0.1%)
性能数据已保存到 performance_stats_serial.npz





In [None]:

print(data.shape)

(3, 3, 6000)
