In [1]:
import glob
import numpy as np
import csv
import pandas as pd
from pathlib import Path

# Length (in samples) of the window that the cropping step cuts out of each
# record, and the margin kept before the P pick / required after the S pick.
CROP_LENGTH = 3000
PICK_MARGIN = 500

# For every raw dataset, scan the .npz records and write a fileinfo.csv that
# lists each file's size, its picks and the sample at which the crop should start.
for dataset_type in ["Learning_org", "Test_org"]:
    results = [["filepath", "waveform_size", "waveform_components", "p_idx", "s_idx", "start_idx"]]
    # Sorted so the CSV row order is reproducible (glob order is file-system dependent).
    for fn in sorted(glob.glob(f"data/{dataset_type}/*.npz")):
        # NpzFile keeps a file handle open until it is closed; release it per file.
        with np.load(fn) as data:
            waveform = data["wave"]
            pidx = data["pidx"]
            sidx = data["sidx"]
        components = waveform.shape[0]
        waveform_size = waveform.shape[1]

        if sidx + PICK_MARGIN > CROP_LENGTH:
            # The S pick (plus margin) does not fit in the first CROP_LENGTH samples:
            # start PICK_MARGIN samples before P, but never run past the record end.
            if pidx - PICK_MARGIN + CROP_LENGTH > waveform_size:
                start_idx = waveform_size - CROP_LENGTH
            else:
                start_idx = pidx - PICK_MARGIN
        else:
            start_idx = 0

        # A negative start (P pick earlier than PICK_MARGIN, or a record shorter
        # than CROP_LENGTH) would be interpreted as "count from the end" when the
        # waveform is sliced, producing an empty or wrong window.
        if start_idx < 0:
            start_idx = 0

        results.append([fn, waveform_size, components, pidx, sidx, start_idx])

    # newline="" is what the csv module expects; without it Windows gets blank rows.
    with open(f"data/{dataset_type}/fileinfo.csv", "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(results)

In [2]:
import numpy as np
from pathlib import Path
from tqdm import tqdm
import pandas as pd

def modify_data(waveform, pidx, sidx, start_idx, window_size=3000):
    """Crop a fixed-length window out of a waveform and shift the picks to match.

    Args:
        waveform: 2-D array indexed as ``[component, sample]``.
        pidx: P-pick sample index in the uncropped waveform.
        sidx: S-pick sample index in the uncropped waveform.
        start_idx: Requested first sample of the window.
        window_size: Number of samples to keep (default 3000).

    Returns:
        tuple: ``(cropped_waveform, shifted_pidx, shifted_sidx)`` where the picks
        are expressed relative to the first sample actually kept.

    The requested start is clamped to ``[0, n_samples - window_size]`` so the
    window always lies inside the record.  A negative start would otherwise be
    treated by slicing as an offset from the end and silently produce an empty
    or wrong window.  If the record is shorter than ``window_size`` the whole
    record is returned.
    """
    n_samples = waveform.shape[1]
    start = max(0, min(int(start_idx), n_samples - window_size))
    modified_waveform = waveform[:, start: start + window_size]
    modified_pidx = pidx - start
    modified_sidx = sidx - start
    return modified_waveform, modified_pidx, modified_sidx

# Crop every record listed in fileinfo.csv with modify_data and write the
# result, under the same file name, to the sibling directory without "_org".
for dataset_type in ["Learning_org", "Test_org"]:
    save_dir = Path(f"data/{dataset_type.replace('_org', '')}")
    save_dir.mkdir(parents=True, exist_ok=True)
    df = pd.read_csv(f"data/{dataset_type}/fileinfo.csv")
    for (fn, start_idx) in tqdm(df[["filepath", "start_idx"]].values, total=len(df)):
        # NpzFile keeps a file handle open until closed; with thousands of
        # records that leaks descriptors, so read the fields and close it.
        with np.load(fn) as data:
            waveform = data["wave"]
            pidx = data["pidx"]
            sidx = data["sidx"]
        modified_waveform, modified_pidx, modified_sidx = modify_data(waveform, pidx, sidx, start_idx)
        # Build the destination from save_dir and the file name rather than
        # replacing every "_org" substring anywhere in the source path.
        np.savez(
            save_dir / Path(fn).name,
            wave=modified_waveform,
            pidx=modified_pidx,
            sidx=modified_sidx
        )

100%|██████████| 4935/4935 [00:02<00:00, 1980.62it/s]
100%|██████████| 500/500 [00:00<00:00, 2017.59it/s]


In [3]:
import os
import glob
import numpy as np
import torch
from pathlib import Path
from sklearn.model_selection import train_test_split
from obspy import Stream, Trace, UTCDateTime
from tqdm import tqdm
import seisbench.models as sbm

# --- Signal / STFT configuration ---
SAMPLE_RATE = 100              # samples per second (presumably Hz; matches the ObsPy sampling_rate default below)
WIN_LENGTH = 30                # win_length passed to torch.stft, in samples
HOP_LENGTH = 15                # hop_length passed to torch.stft, in samples
N_FFT = 60                     # n_fft passed to torch.stft
DURATION = SAMPLE_RATE * 30    # 3000 samples; not referenced in the visible cells
EPS = 1e-8                     # added to the denominator in create_mask to avoid division by zero

# Seed NumPy's global RNG so the noise augmentation is repeatable.
np.random.seed(5)
# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Load the DeepDenoiser model ===
model = sbm.DeepDenoiser.from_pretrained("original")

# === Directory layout ===
input_dir = "data/Learning"
output_root = Path("unet_snr/datasets")
train_dir, val_dir = output_root / "train", output_root / "val"
# Create both output directories up front (no error if they already exist).
for split_dir in (train_dir, val_dir):
    split_dir.mkdir(parents=True, exist_ok=True)

# Component names, in the row order used for the waveform arrays.
channels = ["UD", "NS", "EW"]


def compute_stft_with_phase_rotation(wave):
    """Return the scaled complex STFT of a NumPy waveform.

    The transform uses a Hann window of ``WIN_LENGTH`` samples, a hop of
    ``HOP_LENGTH`` samples and ``N_FFT`` FFT points, with constant (zero)
    padding, and the result is multiplied by ``2 / WIN_LENGTH``.

    NOTE(review): despite the name, no phase rotation is applied here.

    Args:
        wave: NumPy array handed to ``torch.from_numpy``.

    Returns:
        Complex tensor produced by ``torch.stft`` (``return_complex=True``).
    """
    signal = torch.from_numpy(wave)
    hann = torch.hann_window(WIN_LENGTH)
    spectrum = torch.stft(
        signal,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        win_length=WIN_LENGTH,
        window=hann,
        pad_mode="constant",
        normalized=False,
        return_complex=True,
    )
    # Amplitude scaling by 2 / window length (presumably a window normalisation — confirm).
    return (2 / WIN_LENGTH) * spectrum

def process_wave(wave):
    """Compute the STFT of ``wave`` and return it as a real-valued tensor.

    The real and imaginary parts of the complex spectrogram are stacked along
    a new leading axis, giving shape ``[2, F, T]``.
    """
    spectrum = compute_stft_with_phase_rotation(wave)
    real_part, imag_part = spectrum.real, spectrum.imag
    return torch.stack((real_part, imag_part), dim=0)

def create_mask(wave_stft, denoised_stft, pidx, hop_length=None, eps=None):
    """Build a [0, 1] ratio mask from a noisy and a denoised spectrogram.

    The mask is ``|denoised| / (|noisy| + eps)`` clipped to ``[0, 1]``.  Frames
    before the P pick (all frames with index below ``int(pidx / hop_length) - 1``)
    are overwritten with the constant ``0.01``.

    Args:
        wave_stft: Noisy spectrogram; index 0 is the real part, index 1 the
            imaginary part.
        denoised_stft: Denoised spectrogram in the same layout.
        pidx: P-pick position in samples.
        hop_length: STFT hop in samples; defaults to the module-level ``HOP_LENGTH``.
        eps: Denominator guard; defaults to the module-level ``EPS``.

    Returns:
        Tensor of magnitude ratios with the pre-P frames set to ``0.01``.
    """
    if hop_length is None:
        hop_length = HOP_LENGTH
    if eps is None:
        eps = EPS

    stft_abs = torch.abs(wave_stft[0] + wave_stft[1] * 1j)
    denoised_abs = torch.abs(denoised_stft[0] + denoised_stft[1] * 1j)
    mask = denoised_abs / (stft_abs + eps)
    mask = torch.clip(mask, 0, 1)

    # Clamp at 0: for pidx < hop_length the raw value is -1, and a slice
    # ending at -1 would overwrite every frame except the last one.
    n_pre_frames = max(int(pidx / hop_length) - 1, 0)
    mask[:, :n_pre_frames] = 0.01
    return mask

# === ObsPy conversion helpers ===
def convert_ndarry_stream(data, time_str, station_name, sampling_rate=100):
    """Wrap a 3-component array in an ObsPy ``Stream``.

    Args:
        data: Indexable waveform container; rows 0, 1 and 2 become the
            "UD", "NS" and "EW" traces respectively.
        time_str: Start-time string laid out as two-digit year, month and day
            in characters 0-5 and hour, minute and second in characters 7-12
            (character 6 is ignored).  2000 is added to the year.
        station_name: Value stored in each trace's ``station`` field.
        sampling_rate: Value stored in each trace's ``sampling_rate`` field.

    Returns:
        Stream holding one ``Trace`` per component, all sharing the same
        start time, network ("MeSO-net"), station and empty location.
    """
    yy, month, day = (int(time_str[k:k + 2]) for k in (0, 2, 4))
    hour, minute, second = (int(time_str[k:k + 2]) for k in (7, 9, 11))
    start_time = UTCDateTime(2000 + yy, month, day, hour, minute, second)

    # Header fields shared by every component.
    shared_header = {
        "sampling_rate": sampling_rate,
        "starttime": start_time,
        "network": "MeSO-net",
        "station": station_name,
        "location": "",
    }

    stream = Stream()
    for row, code in enumerate(["UD", "NS", "EW"]):
        component = Trace(data=data[row])
        component.stats.update({**shared_header, "channel": code})
        stream.append(component)
    return stream

def convert_stream_to_ndarray(stream, channel_order=["UD", "NS", "EW"]):
    """Stack the data of the requested channels into one array.

    Args:
        stream: Object exposing ``select(channel=...)``; the first match for
            each channel code is used.
        channel_order: Channel codes, in the row order of the output.

    Returns:
        ``np.stack`` of the selected traces' ``data`` arrays.

    Raises:
        ValueError: If a requested channel has no match in ``stream``.
    """
    def _first_match_data(code):
        # Data of the first trace whose channel equals ``code``.
        matches = stream.select(channel=code)
        if not len(matches):
            raise ValueError(f"Channel {code} not found in the stream.")
        return matches[0].data

    return np.stack([_first_match_data(code) for code in channel_order])

# === List the input files and split them into train / validation sets ===
npz_pattern = os.path.join(input_dir, "*.npz")
npz_files = sorted(glob.glob(npz_pattern))
# 10 % of the files go to validation; the fixed random_state keeps the split repeatable.
train_files, val_files = train_test_split(
    npz_files,
    test_size=0.1,
    random_state=42,
)

# === Processing function (with optional noise augmentation) ===
def process_and_save(file_list, out_dir, add_noise=True, num_augments=2):
    """Build and save spectrogram / mask samples for every file in ``file_list``.

    Each ``.npz`` record is demeaned per component and denoised with the
    module-level DeepDenoiser ``model``.  For every component one ``.pt`` file
    is written containing the waveform spectrogram, the denoised spectrogram,
    the mask from ``create_mask`` and the picks.  When ``add_noise`` is true,
    ``num_augments`` additional samples per component are written in which
    Gaussian noise (scaled to 20-50 % of the component's standard deviation)
    has been added to the waveform.

    Files whose name does not split into exactly ``<time>_<station>`` or
    whose denoising fails are skipped with a warning.

    Args:
        file_list: Paths of the ``.npz`` records to process.
        out_dir: ``Path`` of the directory receiving the ``.pt`` files.
        add_noise: Whether to write noise-augmented samples as well.
        num_augments: Number of augmented samples per component.

    Returns:
        int: Number of ``.pt`` samples actually written (skipped files
        contribute nothing).
    """
    n_saved = 0
    for fn in tqdm(file_list, desc=f"Processing {out_dir.name}"):
        # NpzFile keeps a file handle open until closed; release it per file.
        with np.load(fn) as data:
            wave = data["wave"].astype(np.float32)  # (3, 3000)
            pidx = int(data["pidx"])
            sidx = int(data["sidx"])
        wave -= np.mean(wave, axis=1, keepdims=True)

        try:
            time_str, station_name = os.path.basename(fn).replace(".npz", "").split("_")
        except ValueError as e:
            # Raised by the tuple unpacking when the name has no or several "_".
            print(f"[WARN] Skipping invalid filename {fn}: {e}")
            continue

        # Best effort: any failure inside the third-party denoiser skips the file.
        try:
            original_stream = convert_ndarry_stream(wave, time_str, station_name)
            denoised_stream = model.annotate(original_stream)

            denoised = convert_stream_to_ndarray(
                denoised_stream,
                channel_order=["DeepDenoiser_UD", "DeepDenoiser_NS", "DeepDenoiser_EW"]
            ).astype(np.float32)

        except Exception as e:
            print(f"[WARN] Failed to denoise {fn}: {e}")
            continue

        for i in range(3):
            # Save the un-augmented sample.
            tmp_wave = wave[i]
            tmp_denoised = denoised[i]
            stft_real_img = process_wave(tmp_wave)
            de_stft_real_img = process_wave(tmp_denoised)
            mask = create_mask(stft_real_img, de_stft_real_img, pidx)
            base_name = os.path.basename(fn).replace(".npz", f"_{channels[i]}")

            save_dict = {
                "spec": stft_real_img,
                "de_spec": de_stft_real_img,
                "mask": mask,
                "pidx": pidx,
                "sidx": sidx,
                "name": f"{base_name}.pt"
            }
            torch.save(save_dict, out_dir / save_dict["name"])
            n_saved += 1

            # === Noise augmentation ===
            if add_noise:
                # The clean component's spread does not change between augmentations.
                std_per_channel = np.std(tmp_wave)
                for aug_id in range(num_augments):
                    noise_strength = np.random.uniform(0.2, 0.5)
                    noise = np.random.normal(
                        scale=std_per_channel * noise_strength,
                        size=tmp_wave.shape
                    ).astype(np.float32)
                    noisy_wave = tmp_wave + noise
                    noisy_wave -= np.mean(noisy_wave)
                    noisy_stft_real_img = process_wave(noisy_wave)
                    # The target stays the denoised spectrogram of the clean input.
                    mask_aug = create_mask(noisy_stft_real_img, de_stft_real_img, pidx)

                    aug_dict = {
                        "spec": noisy_stft_real_img,
                        "de_spec": de_stft_real_img,
                        "mask": mask_aug,
                        "pidx": pidx,
                        "sidx": sidx,
                        "name": f"{base_name}_aug{aug_id}.pt"
                    }
                    torch.save(aug_dict, out_dir / aug_dict["name"])
                    n_saved += 1

    return n_saved

# === Run ===
num_augments = 2
add_noise = True
# Nominal samples per input file: 3 components x (original + augmentations).
total_per_file = 3 * ((1 + num_augments) if add_noise else 1)
n_train_saved = process_and_save(train_files, train_dir, add_noise=add_noise, num_augments=num_augments)
n_val_saved = process_and_save(val_files, val_dir, add_noise=add_noise, num_augments=num_augments)

# Report what was actually written; the nominal count over-reports when files are skipped.
print(f"✅ Saved {n_train_saved} original + augmented training samples to {train_dir}")
print(f"✅ Saved {n_val_saved} validation samples to {val_dir}")

n_expected = (len(train_files) + len(val_files)) * total_per_file
n_missing = n_expected - (n_train_saved + n_val_saved)
if n_missing:
    print(f"[WARN] {n_missing} of {n_expected} expected samples were not written (skipped files)")


Processing train: 100%|██████████| 4441/4441 [01:07<00:00, 65.34it/s]
Processing val: 100%|██████████| 494/494 [00:07<00:00, 65.61it/s]

✅ Saved 39969 original + augmented training samples to unet_snr/datasets/train
✅ Saved 4446 validation samples to unet_snr/datasets/val



