# パッケージインストール

In [None]:
!pip install -q soundfile onnx onnxruntime

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h

# 重みダウンロード

In [None]:
!wget https://github.com/Xiaobin-Rong/gtcrn/raw/refs/heads/main/stream/onnx_models/gtcrn_simple.onnx -q

# サンプル音声ダウンロード

In [None]:
!wget https://github.com/Xiaobin-Rong/gtcrn/raw/refs/heads/main/stream/test_wavs/mix.wav -q

In [None]:
# numpy is needed for the down-mix below; import it here so this cell also
# works on a fresh kernel (it was previously only imported in a later cell).
import numpy as np
import soundfile as sf
from IPython.display import Audio

# Load the noisy recording as float32 samples.
waveform, sampling_rate = sf.read("mix.wav", dtype="float32")
if waveform.ndim > 1:
    # Multi-channel input: average the channels (axis 1) down to mono.
    waveform = np.mean(waveform, axis=1)

# Play back the input signal.
Audio(waveform, rate=sampling_rate)

# モデル準備

In [None]:
import numpy as np
import onnxruntime

# Open the ONNX model in a CPU-only inference session.
model = onnxruntime.InferenceSession(
    "gtcrn_simple.onnx",
    providers=["CPUExecutionProvider"],
)

# Zero-initialised float32 tensors that hold GTCRN's internal state between
# frames; they are passed to the model inputs of the same names.
conv_cache, tra_cache, inter_cache = (
    np.zeros(shape, dtype=np.float32)
    for shape in (
        (2, 1, 16, 16, 33),
        (2, 3, 1, 1, 16),
        (2, 1, 33, 16),
    )
)

In [None]:
%%time

import copy

buffer = copy.deepcopy(waveform)

# 設定
window_size = 512
n_fft = 512
hop_size = 256
window = np.sqrt(np.hanning(window_size))

# 出力用バッファ
output_audio = np.zeros(len(buffer) + window_size, dtype=np.float32)
output_index = 0

while len(buffer) > 0:
    # チャンクデータを取得、ウィンドウサイズに満たない場合はゼロパディング
    chunk = buffer[:window_size]
    if len(chunk) < window_size:
        chunk = np.pad(
            chunk, (0, window_size - len(chunk)), constant_values=0
        )

    # フーリエ変換
    chunk_spec = np.fft.rfft(chunk * window, n=n_fft)
    real = np.real(chunk_spec)
    imag = np.imag(chunk_spec)
    input_data = np.stack([real, imag], axis=-1)[None, :, None, :]  # (1, 257, 1, 2)

    # 推論
    output_data, conv_cache, tra_cache, inter_cache = model.run(
        None,
        {
            "mix": input_data.astype(np.float32),
            "conv_cache": conv_cache,
            "tra_cache": tra_cache,
            "inter_cache": inter_cache,
        },
    )

    # 逆実数フーリエ変換
    out_real = output_data[0][:, 0, 0]
    out_imag = output_data[0][:, 0, 1]
    enhanced_spec = out_real + 1j * out_imag

    time_chunk= np.fft.irfft(enhanced_spec, n=n_fft)[:window_size]

    # オーバーラップ加算
    time_chunk = time_chunk * window
    output_audio[output_index : output_index + window_size] += time_chunk

    # 位置更新
    output_index += hop_size
    buffer = buffer[hop_size:]

CPU times: user 1.88 s, sys: 36.1 ms, total: 1.92 s
Wall time: 2.31 s


In [None]:
# Save the enhanced signal as a WAV file, trimmed to the input length
sf.write(
    file="output.wav",
    data=output_audio[: len(waveform)],
    samplerate=sampling_rate,
)

# Check: read the file back and play it
temp_waveform, temp_sampling_rate = sf.read(file="output.wav", dtype="float32")
if temp_waveform.ndim > 1:
    # Multi-channel data: average the channels (axis 1) down to mono
    temp_waveform = temp_waveform.mean(axis=1)

Audio(data=temp_waveform, rate=temp_sampling_rate)