# In [None]:
import io
import json
import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output, Image

from noise_pipeline.spectrogram_modifier import SpectrogramModifier
from noise_pipeline.factories import ShapeFactory, PatternFactory
from noise_pipeline.noise_pipeline import NoisePipeline


def log_debug(msg):
    """Print *msg* into the ``debug_output`` widget.

    ``debug_output`` is looked up at call time, so this must only be called
    after that widget has been created.
    """
    target = debug_output
    with target:
        print(msg)


def create_silence_signal(sr, duration):
    """Return an all-zero signal that is *duration* seconds long at rate *sr*.

    The sample count is ``int(sr * duration)`` (truncated, not rounded). It is
    logged through ``log_debug`` before the array is allocated.
    """
    n_samples = int(sr * duration)
    log_debug(f"Silence signal length: {n_samples}")
    return np.zeros(shape=n_samples)


def plot_spectrogram_to_image(spectro_mod):
    """Render the modifier's spectrogram into an embedded PNG ``Image``.

    Calls ``spectro_mod.plot_spectrogram(show_labels=True)`` and uses the
    figure from the returned pair. The figure is saved to an in-memory PNG
    and then closed. The PNG bytes are wrapped in an IPython ``Image`` with
    ``embed=True``.

    Returns:
        IPython.display.Image holding the PNG data.
    """
    log_debug("Plotting spectrogram")
    fig, _ = spectro_mod.plot_spectrogram(show_labels=True)
    # Close the figure even if saving fails; otherwise every failed render
    # leaves an open matplotlib figure behind.
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            png_bytes = buf.getvalue()
    finally:
        plt.close(fig)
    log_debug("Spectrogram image 생성 완료")
    return Image(png_bytes, embed=True)


# Default Base Spectrogram settings. They seed the Base widgets below and are
# overwritten in place by update_base_settings(). Every key except "duration"
# is forwarded to SpectrogramModifier; run_pipeline() reads "sample_rate" and
# "duration" to size the silent input signal.
base_params = {
    # Multiplied by "duration" (seconds) to get a sample count, i.e. samples
    # per second.
    "sample_rate": 16000,
    "duration": 5.0,
    "n_fft": 1024,
    "hop_length": 512,
    "window": "hann",
    # NOTE(review): presumably a dB level (the widget allows -100..100);
    # confirm against SpectrogramModifier.
    "noise_strength": -40,
}

# Module-level state shared by the widget callbacks below.
# update_base_settings() replaces spectro_mod with a new instance and
# re-attaches it to the pipeline via pipeline.spectro_mod.
spectro_mod = SpectrogramModifier(
    sample_rate=base_params["sample_rate"],
    n_fft=base_params["n_fft"],
    hop_length=base_params["hop_length"],
    window=base_params["window"],
    noise_strength=base_params["noise_strength"],
)
pipeline = NoisePipeline(spectro_mod)
shape_factory = ShapeFactory()
pattern_factory = PatternFactory()

# Human-readable descriptions of the noises added so far (one string per
# shape/pattern), rendered by update_noise_list(). The add handlers append
# here after registering on the pipeline; clear_noise_list() resets both.
added_noises = []

# Selectable noise types and the default JSON shown in each parameter editor.
# The defaults are written as Python dicts and serialized with json.dumps.
# With its default ", " / ": " separators this yields exactly the text the
# textareas display, and the result is guaranteed to be valid JSON.

# Circle parameters reused by several shape and pattern defaults.
_circle_shape_params = {
    "center_freq": 4000,
    "center_time": 2.5,
    "radius_freq": 200,
    "radius_time": 0.5,
    "strength_dB": 5,
}

default_shape_params_dict = {
    name: json.dumps(params)
    for name, params in {
        "circle": _circle_shape_params,
        "trapezoid": {
            "freq_min": 300,
            "freq_max": 5000,
            "time_min": 0.5,
            "time_max": 3.0,
            "slope_freq": 1.0,
            "slope_time": 1.0,
            "strength_dB": 5,
        },
        "rectangle": {
            "freq_min": 300,
            "freq_max": 5000,
            "time_min": 0.5,
            "time_max": 3.0,
            "strength_dB": 5,
        },
        "ellipse": _circle_shape_params,
        "fog": {"strength_dB": 5, "coverage": 1.0},
        "hill": {
            "freq_center": 4000,
            "time_center": 2.5,
            "freq_width": 500,
            "time_width": 0.5,
            "strength_dB": 5,
        },
        "polygon": {
            "vertices": [[300, 0.5], [4000, 2.5], [5000, 3.0]],
            "strength_dB": 5,
        },
        # Overriding an existing key keeps its position, so key order matches
        # the circle defaults.
        "horizontal_spike": {**_circle_shape_params, "radius_freq": 100},
        "vertical_spike": {**_circle_shape_params, "radius_freq": 100},
        "pillar": {"freq_min": 300, "freq_max": 5000, "strength_dB": 5},
        "horizontal_line": {
            "center_freq": 4000,
            "strength_dB": 5,
            "thickness": 2,
        },
        "vertical_line": {
            "center_time": 2.5,
            "strength_dB": 5,
            "thickness": 2,
        },
        "horizontal_range_dist_db": {
            "freq_min": 300,
            "freq_max": 5000,
            "strength_dB": 5,
            "distribution": "gaussian",
            "distribution_params": {"sigma": 1000},
        },
        "vertical_range_dist_db": {
            "time_min": 0.5,
            "time_max": 3.0,
            "strength_dB": 5,
            "distribution": "gaussian",
            "distribution_params": {"sigma": 0.5},
        },
    }.items()
}

default_pattern_params_dict = {
    name: json.dumps(params)
    for name, params in {
        "linear": {
            "shape_name": "circle",
            "shape_params": _circle_shape_params,
            "repeat": 3,
            "spacing": 1.0,
            "direction": "time",
        },
        "random": {
            "shape_name": "circle",
            "shape_params": _circle_shape_params,
            "n": 5,
            "freq_range": [300, 5000],
            "time_range": [0.5, 3.0],
        },
        "n_linear_repeat_t_sleep": {
            "shape_name": "circle",
            "shape_params": _circle_shape_params,
            "repeat": 3,
            "repeat_time": 0.5,
            "sleep_time": 1.0,
            "start_time": 0.0,
            "direction": "time",
        },
        "convex": {
            "shape_name": "circle",
            "shape_params": _circle_shape_params,
            "freq_min": 300,
            "freq_max": 5000,
            "time_min": 0.5,
            "time_max": 3.0,
            "n": 5,
        },
        "function": {"func": "lambda ff, tt: np.sin(ff)*np.cos(tt)*5"},
    }.items()
}

# Dropdown options, in the same order as the default tables above.
shape_options = list(default_shape_params_dict)
pattern_options = list(default_pattern_params_dict)


# Base Spectrogram widgets. Each one is seeded from base_params. Their values
# are only read when update_base_btn is clicked (see update_base_settings),
# so editing a field has no effect until then.
sr_text = widgets.BoundedIntText(
    value=base_params["sample_rate"],
    min=8000,
    max=48000,
    description="Sample Rate:",
    layout=widgets.Layout(width="250px"),
)
# Duration of the generated silent signal, in seconds.
dur_text = widgets.BoundedFloatText(
    value=base_params["duration"],
    min=1.0,
    max=60.0,
    step=0.5,
    description="Duration(s):",
    layout=widgets.Layout(width="250px"),
)
nfft_text = widgets.BoundedIntText(
    value=base_params["n_fft"],
    min=256,
    max=8192,
    step=256,
    description="n_fft:",
    layout=widgets.Layout(width="250px"),
)
hop_text = widgets.BoundedIntText(
    value=base_params["hop_length"],
    min=64,
    max=2048,
    step=64,
    description="hop_length:",
    layout=widgets.Layout(width="250px"),
)
# Free text forwarded as-is to SpectrogramModifier(window=...); it is not
# validated here.
window_text = widgets.Text(
    value=base_params["window"],
    description="Window:",
    layout=widgets.Layout(width="250px"),
)
noise_strength_text = widgets.BoundedFloatText(
    value=base_params["noise_strength"],
    min=-100.0,
    max=100.0,
    step=0.05,
    description="Noise Strength:",
    layout=widgets.Layout(width="250px"),
)
update_base_btn = widgets.Button(
    description="Base 설정 업데이트", button_style="primary"
)


def update_base_settings(b):
    """Apply the Base Spectrogram widget values (button callback).

    Reads every Base widget and builds a new ``SpectrogramModifier`` from
    those values. Only if that constructor succeeds does it commit the values
    into ``base_params``, replace the global ``spectro_mod`` and attach it to
    ``pipeline``. If the constructor raises, nothing is changed and the
    failure is reported in the status label and the debug output. This keeps
    ``base_params`` (read by run_pipeline) consistent with the live modifier.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    global spectro_mod
    new_params = {
        "sample_rate": sr_text.value,
        "duration": dur_text.value,
        "n_fft": nfft_text.value,
        "hop_length": hop_text.value,
        "window": window_text.value,
        "noise_strength": noise_strength_text.value,
    }

    try:
        new_mod = SpectrogramModifier(
            sample_rate=new_params["sample_rate"],
            n_fft=new_params["n_fft"],
            hop_length=new_params["hop_length"],
            window=new_params["window"],
            noise_strength=new_params["noise_strength"],
        )
    except Exception as e:
        # UI boundary: report instead of letting the callback raise, and keep
        # the previous, still-consistent settings.
        log_debug(f"Base 설정 업데이트 실패: {e}")
        status_label.value = "Base 설정 업데이트 실패"
        return

    # Mutate in place so every reference to base_params sees the new values.
    base_params.update(new_params)
    spectro_mod = new_mod
    pipeline.spectro_mod = spectro_mod
    log_debug(f"Base 설정 업데이트됨: {base_params}")
    status_label.value = "Base 설정 업데이트됨"


# Base settings panel: title, three rows of two fields each, then the button.
base_box = widgets.VBox(
    children=[
        widgets.Label("Base Spectrogram 설정"),
        widgets.HBox(children=[sr_text, dur_text]),
        widgets.HBox(children=[nfft_text, hop_text]),
        widgets.HBox(children=[window_text, noise_strength_text]),
        update_base_btn,
    ]
)

# Clicking the button applies the widget values (see update_base_settings).
update_base_btn.on_click(update_base_settings)

# Shape Noise widgets: type selector, JSON parameter editor (pre-filled with
# the default for the first shape type) and the "add" button. The editor text
# is parsed with json.loads in add_shape_noise().
shape_dropdown = widgets.Dropdown(
    options=shape_options,
    description="Shape 타입:",
    value=shape_options[0],
)
shape_text = widgets.Textarea(
    value=default_shape_params_dict.get(shape_options[0], "{}"),
    description="Shape 파라메터:",
    layout=widgets.Layout(width="90%", height="100px"),
)
add_shape_btn = widgets.Button(
    description="Shape Noise 추가", button_style="info"
)


def update_shape_params(change):
    """Reset the Shape parameter editor when a different shape is selected.

    ``change["new"]`` is the newly selected shape name. The editor is filled
    with that shape's default JSON, or ``"{}"`` when no default exists.
    Any edits already in the editor are overwritten.
    """
    selected = change["new"]
    defaults = default_shape_params_dict.get(selected, "{}")
    log_debug(f"선택된 Shape: {selected}, 기본 파라메터: {defaults}")
    shape_text.value = defaults


shape_dropdown.observe(update_shape_params, names="value")

# Pattern Noise widgets: type selector, JSON parameter editor (pre-filled with
# the default for the first pattern type) and the "add" button. The editor
# text is parsed with json.loads in add_pattern_noise().
pattern_dropdown = widgets.Dropdown(
    options=pattern_options,
    description="Pattern 타입:",
    value=pattern_options[0],
)
pattern_text = widgets.Textarea(
    value=default_pattern_params_dict.get(pattern_options[0], "{}"),
    description="Pattern 파라메터:",
    layout=widgets.Layout(width="90%", height="100px"),
)
add_pattern_btn = widgets.Button(
    description="Pattern Noise 추가", button_style="info"
)


def update_pattern_params(change):
    """Reset the Pattern parameter editor when a different pattern is selected.

    ``change["new"]`` is the newly selected pattern name. The editor is filled
    with that pattern's default JSON, or ``"{}"`` when no default exists.
    Any edits already in the editor are overwritten.
    """
    selected = change["new"]
    defaults = default_pattern_params_dict.get(selected, "{}")
    log_debug(f"선택된 Pattern: {selected}, 기본 파라메터: {defaults}")
    pattern_text.value = defaults


pattern_dropdown.observe(update_pattern_params, names="value")

# Global action buttons: run the pipeline / clear every queued noise.
generate_btn = widgets.Button(
    description="노이즈 생성", button_style="success"
)
clear_btn = widgets.Button(
    description="리스트 초기화", button_style="warning"
)

# Output areas. log_debug() and the callbacks above look these names up at
# call time, so they only need to exist before any callback fires.
# noise_list shows the added-noise descriptions and the [DEBUG] error lines
# printed by the handlers.
noise_list = widgets.Output()
status_label = widgets.Label(value="대기 중")
spectrogram_output = widgets.Output()
debug_output = widgets.Output()


def add_shape_noise(b):
    """Create a shape noise from the editor and register it (button callback).

    Parses ``shape_text`` as JSON and requires a JSON object. Builds the shape
    with ``shape_factory.create(type, **params)``, adds it to ``pipeline`` and
    records a description in ``added_noises``. Parse, validation and factory
    errors are printed into ``noise_list``, and nothing is registered.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    try:
        params = json.loads(shape_text.value)
    except json.JSONDecodeError as e:
        with noise_list:
            print(f"[DEBUG] Shape JSON 오류: {e}")
        return
    # Valid JSON that is not an object (e.g. a list) cannot be spread as
    # keyword arguments; report it clearly instead of as a factory failure.
    if not isinstance(params, dict):
        with noise_list:
            print(
                "[DEBUG] Shape JSON 오류: JSON 객체({...})가 필요합니다 "
                f"(받은 타입: {type(params).__name__})"
            )
        return
    noise_type = shape_dropdown.value
    try:
        noise_obj = shape_factory.create(noise_type, **params)
    except Exception as e:
        with noise_list:
            print(f"[DEBUG] Shape 생성 실패: {e}")
        return
    pipeline.add_shape(noise_obj)
    added_noises.append(f"Shape: {noise_type} | {json.dumps(params)}")
    update_noise_list()
    log_debug(f"Shape '{noise_type}' 추가됨, 파라메터: {params}")
    status_label.value = f"Shape '{noise_type}' 추가됨"


def add_pattern_noise(b):
    """Create a pattern noise from the editor and register it (button callback).

    Parses ``pattern_text`` as JSON and requires a JSON object. Builds the
    pattern with ``pattern_factory.create(type, params)``, adds it to
    ``pipeline`` and records a description in ``added_noises``. Parse,
    validation and factory errors are printed into ``noise_list``, and nothing
    is registered.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    try:
        params = json.loads(pattern_text.value)
    except json.JSONDecodeError as e:
        with noise_list:
            print(f"[DEBUG] Pattern JSON 오류: {e}")
        return
    # Every default pattern config is a JSON object; reject anything else
    # before it reaches the factory.
    if not isinstance(params, dict):
        with noise_list:
            print(
                "[DEBUG] Pattern JSON 오류: JSON 객체({...})가 필요합니다 "
                f"(받은 타입: {type(params).__name__})"
            )
        return
    noise_type = pattern_dropdown.value
    try:
        noise_obj = pattern_factory.create(noise_type, params)
    except Exception as e:
        with noise_list:
            print(f"[DEBUG] Pattern 생성 실패: {e}")
        return
    pipeline.add_pattern(noise_obj)
    added_noises.append(f"Pattern: {noise_type} | {json.dumps(params)}")
    update_noise_list()
    log_debug(f"Pattern '{noise_type}' 추가됨, 파라메터: {params}")
    status_label.value = f"Pattern '{noise_type}' 추가됨"


def update_noise_list():
    """Redraw ``noise_list`` so it shows exactly the entries in ``added_noises``.

    Anything previously printed there is cleared first, including earlier
    [DEBUG] error lines.
    """
    entries = added_noises
    with noise_list:
        clear_output()
        for entry in entries:
            print(entry)


def clear_noise_list(b):
    """Forget every queued noise (button callback).

    Rebinds ``added_noises`` to a fresh list, empties the pipeline's shapes and
    patterns, redraws the list and reports the reset.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    global added_noises
    added_noises = []
    pipeline.shapes, pipeline.patterns = [], []
    update_noise_list()
    message = "Noise 리스트 초기화됨"
    log_debug(message)
    status_label.value = message


def run_pipeline():
    """Apply the queued noises to a silent signal and show the spectrogram.

    Builds a silent signal from ``base_params["sample_rate"]`` and
    ``base_params["duration"]`` and passes it to ``pipeline.generate``. Then it
    logs the min/max of ``spectro_mod.S_db`` and renders the spectrogram into
    ``spectrogram_output``. Progress messages go to the debug output.

    Errors are reported rather than raised: the status label becomes
    "오류 발생", a one-line message is printed into ``noise_list`` and the
    full traceback is written to the debug output.
    """
    import traceback

    try:
        log_debug("노이즈 생성 시작")
        signal = create_silence_signal(
            base_params["sample_rate"], base_params["duration"]
        )
        log_debug("Silence signal 생성됨")
        pipeline.generate(signal)
        log_debug("pipeline.generate() 호출 완료")
        sdb_min = spectro_mod.S_db.min()
        sdb_max = spectro_mod.S_db.max()
        log_debug(f"S_db 범위: {sdb_min} ~ {sdb_max}")
        img = plot_spectrogram_to_image(spectro_mod)
        with spectrogram_output:
            clear_output(wait=True)
            display(img)
        log_debug("Spectrogram 출력 완료")
        status_label.value = "노이즈 생성 완료"
    except Exception as e:
        status_label.value = "오류 발생"
        with noise_list:
            print(f"[DEBUG] 실행 오류: {e}")
        # str(e) alone is often useless (a KeyError shows only the key), so
        # keep the full traceback next to the other debug messages.
        log_debug(traceback.format_exc())


def generate_noise(b):
    """Mark the UI as busy, then run the pipeline (button callback).

    The pipeline runs synchronously on the calling thread (no background
    thread), so its debug messages appear in the output widgets.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    status_label.value = "노이즈 생성 중..."
    return run_pipeline()


# Connect each button to its handler.
generate_btn.on_click(generate_noise)
clear_btn.on_click(clear_noise_list)
add_pattern_btn.on_click(add_pattern_noise)
add_shape_btn.on_click(add_shape_noise)

# Assemble the layout: the editor sections, action buttons, status line and
# output areas stacked vertically, then render it.
shape_box = widgets.VBox(children=[shape_dropdown, shape_text, add_shape_btn])
pattern_box = widgets.VBox(
    children=[pattern_dropdown, pattern_text, add_pattern_btn]
)
control_box = widgets.HBox(children=[generate_btn, clear_btn])
list_box = widgets.VBox(
    children=[widgets.Label("추가된 Noise 목록:"), noise_list]
)
debug_box = widgets.VBox(
    children=[widgets.Label("디버그 출력:"), debug_output]
)
main_box = widgets.VBox(
    children=[
        base_box,
        widgets.Label("Shape Noise 추가"),
        shape_box,
        widgets.Label("Pattern Noise 추가"),
        pattern_box,
        control_box,
        status_label,
        list_box,
        widgets.Label("생성된 Spectrogram:"),
        spectrogram_output,
        debug_box,
    ]
)

display(main_box)


# Output (widget repr, truncated): VBox(children=(VBox(children=(Label(value='Base Spectrogram 설정'), HBox(children=(BoundedIntText(value=16000, d…