# インポート

In [1]:
# import packages

# python
import os
from glob import glob
import math
from typing import Dict, Tuple
import sys

# IPython
from IPython.display import display

# numpy
import numpy as np

# scipy
import scipy.signal
from scipy.optimize import curve_fit
from scipy import integrate

# sympy
import sympy

# bokeh
from bokeh.plotting import output_notebook, figure, show
output_notebook()

# local
from WavData import WavData
import FileUtil
import SignalProcessingUtil as spu
import NotebookUtil as nu

# データディレクトリ内のファイルを列挙

In [2]:
# enumerate
top_dir = os.environ.get('PWD')
display(f'top_dir={top_dir}')
wav_files = glob(os.path.join(top_dir, 'data/*.wav'))

# dump
wav_files

'top_dir=/kick_extractor'

['/kick_extractor/data/blue_eyes.wav',
 '/kick_extractor/data/Freaking Tight - Alex Prospect_2.wav',
 '/kick_extractor/data/last_goodbye_mob.wav',
 '/kick_extractor/data/WDYWFM - Alex Prospect_2.wav_est.wav_est.wav',
 '/kick_extractor/data/SHOT ME DOWN - Alex Prospect 2_2.wav',
 '/kick_extractor/data/Tremor - Alex Prospect & WILSXN_2.wav',
 '/kick_extractor/data/02 Dreamer (7iva & saqwz Remix).wav',
 '/kick_extractor/data/satellite.wav',
 '/kick_extractor/data/happy_days_refrain_20180421_2.wav',
 '/kick_extractor/data/save_a_life.wav',
 '/kick_extractor/data/Heaven 2017 - Alex Prospect_2.wav',
 '/kick_extractor/data/stay_young.wav',
 '/kick_extractor/data/Never Forget You - Alex Prospect_2.wav',
 '/kick_extractor/data/WDYWFM - Alex Prospect_2.wav_est.wav',
 '/kick_extractor/data/look_back.wav',
 '/kick_extractor/data/WDYWFM - Alex Prospect_2.wav',
 '/kick_extractor/data/Need U 100% - Alex Prospect & Spyro_2.wav',
 '/kick_extractor/data/all_about_elysium.wav',
 '/kick_extractor/data/INT

# 入力ファイルをロード

In [3]:
# TODO 決め打ち
#src_path = '/kick_extractor/data/WDYWFM - Alex Prospect_2.wav_est.wav'
src_path = '/kick_extractor/data/WDYWFM - Alex Prospect_2.wav'

# load
wavdata_raw = FileUtil.load_wav_file(src_path)

# plot
nu.plot( nu.describe_wavdata('original waveform', wavdata_raw) )

  sample_rate, samples = wavfile.read(src_path)


# pre-process wav data

In [4]:
# TODO 決め打ちのパラメータ
beat_per_minute = 170.0
beat_length_in_sec = 60.0 / 170.0
peak_filter_initial_frequency_in_hz = 80.0
peak_filter_bandwidth_in_hz = 100.0 #200.0

# モノラルにミックスダウン
if wavdata_raw.samples.ndim != 1:
    wavdata_mono = WavData(wavdata_raw.sample_rate, np.mean(wavdata_raw.samples, axis=0))
else:
    wavdata_mono = wavdata_raw
    
# 先頭 1/8 音符分を切り出し
beat_length_in_sample = beat_length_in_sec * wavdata_raw.sample_rate
wavdata_truncated = WavData(wavdata_mono.sample_rate, wavdata_mono.samples[:int(beat_length_in_sample/2)])

# 最も大きい正弦波以外の周波数帯をカットオフして「キレイ」にする
wavdata_smoothed_max = np.zeros_like(wavdata_truncated.samples)
wavdata_smoothed_min = np.zeros_like(wavdata_truncated.samples)
for f in [f for f in np.geomspace(40, 240, 256)]:
    for q in [q for q in np.linspace(100, 200, 16)]:
        wavdata_filtered = spu.apply_peak_filter(wavdata_truncated, peak_filter_bandwidth_in_hz, f).samples
        wavdata_smoothed_max = np.fmax(wavdata_smoothed_max, wavdata_filtered)
        wavdata_smoothed_min = np.fmin(wavdata_smoothed_min, wavdata_filtered)
wavdata_smoothed = WavData(wavdata_truncated.sample_rate, wavdata_smoothed_max + wavdata_smoothed_min)

# リネーム
wavdata_prepro = wavdata_smoothed
    
# 結果をプロット
nu.plot(
    nu.describe_wavdata('original waveform', wavdata_truncated) +
    nu.describe_wavdata('smoothed waveform', wavdata_smoothed) +
    nu.describe_wavdata('pre-processed', wavdata_prepro)
)

# 極値を探索

In [5]:
# 結合
min_positions_in_sample = np.asarray(scipy.signal.argrelmin(wavdata_prepro.samples, order=4)).flatten()
max_positions_in_sample = np.asarray(scipy.signal.argrelmax(wavdata_prepro.samples, order=4)).flatten()
raw_extrema_positions_in_sample = np.sort(np.concatenate([min_positions_in_sample,max_positions_in_sample]).flatten())
display(type(raw_extrema_positions_in_sample))

# プロット
nu.plot(
    nu.describe_dot_on_wavdata('extrema', wavdata_prepro, raw_extrema_positions_in_sample) +
    nu.describe_wavdata('wavdata_prepro', wavdata_prepro)
)

numpy.ndarray

# 隣接＆同符号の極値をグループとみなして代表点で置き換える
- 以下の条件を両方とも満たす極値でグループを作る
    - 隣接している
    - 符号が同一である
- グループごとに代表点を算出して、オリジナルの極値を置き換える
    - 振幅を重みとしたサンプル位置の平均＝代表点

In [6]:
# for ループで頑張る
extrema_positions_in_sample = []
position_buffer = []
value_buffer = []
for position in raw_extrema_positions_in_sample:
    value = wavdata_prepro.samples[position]
    sign = value > 0.0
    if len(position_buffer) != 0:
        last_position = position_buffer[-1]
        last_value = value_buffer[-1]
        last_sign = last_value > 0.0
        if sign != last_sign:
            averaged_position = np.average(position_buffer, weights=np.abs(value_buffer))
            extrema_positions_in_sample.append(averaged_position)
            position_buffer.clear()
            value_buffer.clear()
    position_buffer.append(position)
    value_buffer.append(value)
extrema_positions_in_sample = np.asarray(extrema_positions_in_sample).astype(np.int32)
    
# プロット
nu.plot(
    nu.describe_dot_on_wavdata('extrema', wavdata_prepro, extrema_positions_in_sample) +
    nu.describe_wavdata('wavdata_prepro', wavdata_prepro)
)

# ゼロクロス点（サブサンプル精度）を抽出
- ゼロをまたぐ２点に直線を当てはめて

In [7]:
# 全てのゼロクロス点を抽出
zero_cross_point_in_sample = np.asarray( np.nonzero( ( wavdata_prepro.samples[:-1] > 0.0 ) ^ ( np.roll(wavdata_prepro.samples, -1)[:-1] > 0.0 ) ) ).flatten()

# インライア近傍のゼロクロス点を抽出
zero_cross_point_in_sample = zero_cross_point_in_sample[(extrema_positions_in_sample[0] < zero_cross_point_in_sample) & (zero_cross_point_in_sample < extrema_positions_in_sample[-1])]

# サブサンプルオフセットを推定
#
# ゼロクロス点前後のサンプルを通る直線は
# y = (samples[i+1] - samples[i]) * x + samples[i]
# これを y = 0 として x について解くと
# x = -samples[i] / (samples[i+1] - samples[i])
# これをベクトル演算で頑張ると↓になる
slope = wavdata_prepro.samples[zero_cross_point_in_sample+1] - wavdata_prepro.samples[zero_cross_point_in_sample]
intercept = wavdata_prepro.samples[zero_cross_point_in_sample]
subsample_offset = -intercept / slope

# サブサンプル精度のゼロクロス点を計算
zero_cross_points_in_subsamples = zero_cross_point_in_sample + subsample_offset

# プロット
nu.plot(
    nu.describe_dot_on_wavdata('zero-cross', wavdata_prepro, zero_cross_points_in_subsamples) +
    nu.describe_wavdata('wavdata_prepro', wavdata_prepro)
)

# 極値位置をサブサンプル精度化
- トゥルーピーク（インターサンプルピーク）を得る
- 極値と隣接するサンプルの合計３点でいい感じに重み付き平均みたいなことをする

In [8]:
def extrema_position_to_subsample(samples: np.ndarray, position_in_sample: int):
    '''
    position_in_sample をサブサンプル精度化して返却する。
    '''
    # エイリアス
    magnitude_left = abs(samples[position_in_sample-1])
    magnitude_center = abs(samples[position_in_sample])
    magnitude_right = abs(samples[position_in_sample+1])
    magnitude_1st = magnitude_center
    magnitude_2nd = max(magnitude_left, magnitude_right)
    magnitude_3rd = min(magnitude_left, magnitude_right)
    # サブサンプル単位のオフセットを計算
    subsample_offset = 0.5 * (magnitude_2nd - magnitude_3rd) / (magnitude_1st - magnitude_3rd)
    if magnitude_left > magnitude_right:
        return position_in_sample - subsample_offset
    else:
        return position_in_sample + subsample_offset

# サブサンプル精度化
extrema_positions_in_subsamples = np.array([extrema_position_to_subsample(wavdata_prepro.samples, position_in_sample) for position_in_sample in extrema_positions_in_sample])

# プロット
nu.plot(
    nu.describe_dot_on_wavdata('extrema', wavdata_prepro, extrema_positions_in_sample) +
    nu.describe_dot_on_wavdata('extrema(sub)', wavdata_prepro, extrema_positions_in_subsamples) +
    nu.describe_wavdata('wavdata_prepro', wavdata_prepro)
)

In [9]:
# DEBUG
control_points_in_subsamples = np.sort( np.concatenate([extrema_positions_in_subsamples, zero_cross_points_in_subsamples]) )

# プロット
nu.plot(
    nu.describe_dot_on_wavdata('control point', wavdata_prepro, control_points_in_subsamples) +
    nu.describe_wavdata('wavdata_prepro', wavdata_prepro),
    beat_per_minute=beat_per_minute
)

In [10]:
nu.plot(
    nu.describe_frequency('extrema(subsamples)', extrema_positions_in_subsamples, wavdata_prepro.sample_rate, 2) +
    nu.describe_frequency('zero-cross(subsamples)', zero_cross_points_in_subsamples, wavdata_prepro.sample_rate, 2) +
    nu.describe_frequency('ctrl(subsamples)', control_points_in_subsamples, wavdata_prepro.sample_rate, 4),
    is_log_scale=True
)
nu.plot(
    nu.describe_frequency('extrema(subsamples)', extrema_positions_in_subsamples, wavdata_prepro.sample_rate, 2) +
    nu.describe_frequency('zero-cross(subsamples)', zero_cross_points_in_subsamples, wavdata_prepro.sample_rate, 2),
    is_log_scale=True
)
nu.plot(
    nu.describe_frequency('ctrl(subsamples)', control_points_in_subsamples, wavdata_prepro.sample_rate, 4),
    is_log_scale=True
)

# 制御点リストを生成
- 極値とゼロクロス点をマージする

In [11]:
# 一旦雑にマージ
control_points_in_subsamples = np.concatenate([zero_cross_points_in_subsamples, extrema_positions_in_subsamples])
control_points_sort_arg = np.argsort(control_points_in_subsamples)
control_points_in_subsamples = control_points_in_subsamples[control_points_sort_arg]

# 交互性が保たれていない先頭要素を削除
threshold = np.max(wavdata_prepro.samples[zero_cross_points_in_subsamples.astype(np.int32)])
is_extrema = wavdata_prepro.samples[control_points_in_subsamples.astype(np.int32)]
is_alt = is_extrema != np.roll(is_extrema, -1)
is_alt[-1] = True
control_points_in_subsamples =  control_points_in_subsamples[np.where(is_alt)[0][0]:]

# 交互性が保たれていない末尾要素を削除
is_extrema = wavdata_prepro.samples[control_points_in_subsamples.astype(np.int32)]
is_alt = is_extrema != np.roll(is_extrema, +1)
is_alt[0] = True
control_points_in_subsamples =  control_points_in_subsamples[:np.where(is_alt)[0][-1]+1]

# 残りの制御点情報も生成
control_points_in_samples = control_points_in_subsamples.astype(np.int32)
control_points_in_sec = control_points_in_subsamples / float(wavdata_prepro.sample_rate)
control_points_in_magnitude = wavdata_prepro.samples[control_points_in_samples]

# 結果をプロット
nu.plot(
    nu.describe_dot_on_wavdata('extrema', wavdata_prepro, control_points_in_subsamples) +
    nu.describe_wavdata('wavdata_prepro', wavdata_prepro)
)

# 位相情報を生成

In [12]:
# 位相パターン
threshold = np.max(wavdata_prepro.samples[zero_cross_points_in_subsamples.astype(np.int32)])
initial_quad_phase = None
if (np.abs(control_points_in_magnitude[0]) > threshold):
    if control_points_in_magnitude[0] > 0.0:
        initial_quad_phase = 1 # 0.25
    else:
        initial_quad_phase = 3 # 0.75
else:
    if control_points_in_magnitude[1] > 0.0:
        initial_quad_phase = 0 # 0.00
    else:
        initial_quad_phase = 2 # 0.50

# 位相リストを生成
control_points_in_phase_unit = (np.arange(control_points_in_subsamples.size) + initial_quad_phase) / 4

# プロット
nu.plot(
    nu.describe_scatter('source', control_points_in_phase_unit, control_points_in_sec),
    is_log_scale=False,
    beat_per_minute=beat_per_minute
)

# キックの数式モデルを定義
- 「秒 --> Hz」の関数を定義
- 「秒 --> 位相の」関数は自動積分で得る

In [13]:
# シンボルを定義
sym_x = sympy.Symbol('x')
sym_exp_base = sympy.Symbol('exp_base')
sym_x_scaler = sympy.Symbol('x_scaler')
sym_x_offset = sympy.Symbol('x_offset')
sym_y_scaler = sympy.Symbol('y_scaler')
sym_y_offset = sympy.Symbol('y_offset')
sym_phase_offset = sympy.Symbol('phase_offset')

# 式を定義
exp_2freq = sym_y_scaler * sym_exp_base ** (sym_x_scaler * sym_x + sym_x_offset) + sym_y_offset
exp_2phase = sympy.integrate(exp_2freq, sym_x) + sym_phase_offset

# 式を TeX レンダリング
display(exp_2freq)
display(exp_2phase)

# sympy 式 --> numpy 式
target_function_2freq = sympy.lambdify(
    (sym_x, sym_exp_base, sym_x_scaler, sym_x_offset, sym_y_scaler, sym_y_offset),
    exp_2freq,
    'numpy'
)
target_function_2phase = sympy.lambdify(
    (sym_x, sym_exp_base, sym_x_scaler, sym_x_offset, sym_y_scaler, sym_y_offset, sym_phase_offset),
    exp_2phase,
    'numpy'
)


exp_base**(x*x_scaler + x_offset)*y_scaler + y_offset

phase_offset + x*y_offset + Piecewise((exp_base**(x*x_scaler + x_offset)*y_scaler/(x_scaler*log(exp_base)), Ne(x_scaler*log(exp_base), 0)), (x*y_scaler, True))

# 位相エンベロープで当てはめを行う
- 位相（sin 関数の入力）空間上で、キック数式モデルの当てはめを行う

In [22]:
# 層化ランダム選択（インデックス返却）
def stratrified_random_choice(stop: int, div: int):
    '''
    [0, stop) の範囲を等間隔な div 個の区間に分割し、
    区間ごとに１つの要素をランダム選択する。
    全区間からのランダムサンプリングと比べて「まんべんなく」選択される。
    '''
    regions = np.linspace(0, stop, div+1).astype(np.int32)
    regions_size = np.roll(regions, -1)[:-1] - regions[:-1]
    return (regions[:-1] + regions_size * np.random.rand(regions_size.size)).astype(np.int32)

# LMedS 
def lmeds(
    target_function,
    array_x: np.ndarray,
    array_y: np.ndarray,
    initial_params: list,
    num_iteration: int = 100,
    random_seed: int = 20220116
):
    np.random.seed(seed=random_seed)
    fitting_param = None
    fitting_param_eval = None
    for i in range(num_iteration):
        random_indices = stratrified_random_choice(array_x.size, int(array_x.size/2))
        try:
            temp_fitting_param, _ = curve_fit(
                target_function,
                array_x[random_indices],
                array_y[random_indices],
                p0=initial_params
            )
        except Excpection as e:
            Display(e)
            continue
        temp_residual = target_function(array_x, *temp_fitting_param) - array_y
        temp_fitting_param_eval = np.sort(temp_residual ** 2)[int(array_x.size/2)]
        if fitting_param_eval is None or temp_fitting_param_eval < fitting_param_eval:
            fitting_param = temp_fitting_param
            fitting_param_eval = temp_fitting_param_eval
    return fitting_param


# LMedS でフィッティング
phase_fitting_param = lmeds(
    target_function_2phase,
    control_points_in_sec,
    control_points_in_phase_unit,
    [10.0, -10.0, 0.0, 10.0, 0.0, 0.0],
    100
)

# フィッティングのパラメータを対象関数にバインド
estimated_function_2phase = lambda x: target_function_2phase(x, *phase_fitting_param)

# 推定した周波数関数を元に波形を再生
estimated_positions_in_sample = np.arange(0, beat_length_in_sample/2)
estimated_positions_in_sec = estimated_positions_in_sample / wavdata_prepro.sample_rate
estimated_phases_in_unit = estimated_function_2phase(estimated_positions_in_sec)
estimated_samples = np.sin(estimated_phases_in_unit * 2.0 * np.pi)

# フィッティングの結果をプロット
estimated_wavdata = WavData(wavdata_prepro.sample_rate, estimated_samples)
nu.plot(
    nu.describe_wavdata('original waveform', wavdata_prepro) +
    nu.describe_wavdata('estimated waveform', estimated_wavdata)
)
nu.plot(
    nu.describe_scatter('estimated', estimated_phases_in_unit, estimated_positions_in_sec) +
    nu.describe_scatter('source', control_points_in_phase_unit, control_points_in_sec),
    is_log_scale=False,
    beat_per_minute=beat_per_minute
)

# 開始・終了位相を調整
- 波形の開始・終了時点で任意の位相がちょうどくるように調整する

# 再構築波形をファイル出力

In [None]:
dst_path = src_path + '_est.wav'
FileUtil.save_wav_file(dst_path, estimated_wavdata)

# やってみたけど、結局要らなくなったことリスト

## 検出した極値が正負交互に来る範囲のくくりだし
- 中央らへんからスタートして「正 --> 負 --> 正 --> 負 --> ...」っていう規則が守られている範囲を広げていく
- 「隣接する同符号の極値をグループ化して代表点で置き換える」処理を適用した後なら「正負ルール」が必ず守られているので不要になった

## 前処理でのアップサンプリング
- アップサンプルによって、より正確なピーク位置を算出しようとした
- きれいな丸みのピークの頂上が凹むような現象が現れた
- 極値のサブサンプル化時に問題になるので却下

## 周波数ドメインでの当てはめ
- 制御点の間隔を元に、サンプル位置ごとに周波数を算出し、それに対して周波数関数を当てはめる
- 周波数当てはめでは開始位相を解決できない
- 直接位相を推定できたので必要なくなった