# Graphical Equalizer

In [1]:
# Version 1: each audio block is equalized independently (no overlap between consecutive blocks)

import numpy as np
import sounddevice as sd
import ipywidgets as widgets
import pywt
import soundfile as sf
from scipy.fft import dct, idct
from IPython.display import display, clear_output
import time
import threading

# --- Configuration ---
SAMPLE_RATE = 44100                      # output stream sample rate (Hz)
BLOCK_SIZE = 4096                        # samples requested per audio callback
NUM_BANDS = 32                           # equal-width EQ bands spanning 0 Hz .. Nyquist
WAVELET_LEVEL = int(np.log2(NUM_BANDS))  # packet depth; 2**level leaves == bands when NUM_BANDS is a power of two

# --- DEFAULTS ---
DEFAULT_FILE = "../data/Metallica_Orion.mp3"
DEFAULT_WAV_FAMILY = 'db'
DEFAULT_WAV_FILTER = 'db5'

MAX_SLIDER_GAIN = 1.0                    # upper limit of every band gain slider
CURRENT_WAVELET = DEFAULT_WAV_FILTER     # filter used by the audio callback; changed from the UI

# Base Volumes: fixed per-source scale factors applied before each source's mix slider
BASE_NOISE_VOL = 0.2
BASE_TONAL_VOL = 0.05
BASE_FILE_VOL = 0.8

# --- FREQUENCY CALCULATIONS ---
nyquist = SAMPLE_RATE / 2
bandwidth = nyquist / NUM_BANDS                                   # width of one band (Hz)
end_freqs = np.arange(1, NUM_BANDS + 1) * bandwidth               # upper edge of each band
center_freqs = np.arange(NUM_BANDS) * bandwidth + bandwidth / 2   # midpoint of each band

# --- THRESHOLD OF HEARING MODEL ---
def ToH_model(f):
    """Return the approximate absolute threshold of hearing (dB) at frequency ``f`` in Hz.

    This is the usual Terhardt-style curve: a steep rise toward low frequencies,
    a dip centred on 3.3 kHz and a rise toward high frequencies. ``f`` may be a
    scalar or a NumPy array. Frequencies are clamped to at least 1e-6 Hz so that
    0 Hz does not send the negative-power term to infinity.
    """
    khz = np.maximum(f, 1e-6) / 1000
    low_rise = 3.64 * khz ** (-0.8)
    mid_dip = -6.5 * np.exp(-0.6 * ((khz - 3.3) ** 2))
    high_rise = 10 ** (-3) * khz ** 4
    return low_rise + mid_dip + high_rise

# --- DYNAMIC INITIAL CONFIGURATION ---
# Default band gains follow the inverted, min-max normalised threshold-of-hearing
# curve: the band with the lowest modelled threshold starts at MAX_SLIDER_GAIN,
# the band with the highest threshold starts at 0.
toh_raw = ToH_model(center_freqs)
t_min, t_max = np.min(toh_raw), np.max(toh_raw)
toh_normalized = (toh_raw - t_min) / (t_max - t_min) if t_max > t_min else np.zeros_like(toh_raw)

# Store the initial ToH gains for the Reset button
INITIAL_TOH_GAINS = ((1.0 - toh_normalized) * MAX_SLIDER_GAIN).astype(np.float32)
gains = np.copy(INITIAL_TOH_GAINS)  # live per-band gains: written by the UI, read by the audio callback

# --- Clean up a previous run of this cell ---
# `global_stream` and `stop_threads` are rebound below. If this cell ran before,
# the old OutputStream would keep playing (sd.stop() does not stop it) and the old
# meter thread would keep running, with no handle left to stop either. Shut both
# down first.
_previous_stream = globals().get('global_stream')
if _previous_stream is not None and not getattr(_previous_stream, 'closed', False):
    try:
        _previous_stream.stop()
        _previous_stream.close()
    except Exception as err:
        print(f"Could not close previous stream: {err}")
stop_threads = True  # makes a previous ui_updater loop exit after its current sleep
_previous_ui_thread = globals().get('t')
if isinstance(_previous_ui_thread, threading.Thread) and _previous_ui_thread.is_alive():
    _previous_ui_thread.join(timeout=1.0)

# --- Global State ---
band_active = np.ones(NUM_BANDS, dtype=bool)              # per-band enable flags
band_variances = np.zeros(NUM_BANDS, dtype=np.float32)    # per-band meter levels, written by the callback

global_volume = 1.0
slider_widgets_list = []
meter_widgets_list = []
global_stream = None
stop_threads = False

# Source mix levels (0..1) and processing mode, read by the audio callback
mix_noise = 0.0
mix_tone = 0.0
mix_file = 1.0
passthrough_mode = False
process_method = 'Wavelet'

file_data = None            # mono samples of the loaded file
file_play_head = 0          # next sample index to read from file_data
phases = np.zeros(NUM_BANDS)  # per-band tone oscillator phase (radians)
# Families offered in the UI: those not listed here are excluded from the dropdown.
available_families = [f for f in pywt.families(short=True) if f not in ['bior', 'rbio', 'dmey', 'gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor']]

# --- 1. The Audio Callback ---
def audio_callback(outdata, frames, time_info, status):
    """Build, equalize and emit one block as a sounddevice ``OutputStream`` callback.

    Args:
        outdata: output buffer, filled in place. Only column 0 is written,
            because the stream is opened with ``channels=1``.
        frames: number of samples requested for this block.
        time_info: stream timing information (unused).
        status: over/underflow flags; printed when set.

    Side effects on module globals:
        - Advances ``phases`` (per-band tone phase).
        - Advances ``file_play_head`` (read position in ``file_data``).
        - Overwrites ``band_variances`` with this block's per-band meter levels
          for the UI thread.
    """
    global phases, file_play_head, band_variances
    if status:
        print(status)

    current_variances = np.zeros(NUM_BANDS)

    # Master volume effectively zero: emit silence and clear the meters.
    if global_volume <= 0.001:
        outdata.fill(0)
        band_variances[:] = 0
        return

    final_mix = np.zeros(frames)
    broadband_signal = np.zeros(frames)
    has_broadband_content = False

    # A. Tone generator: one sine per enabled band at the band's centre frequency,
    # already scaled by the band gain. It is added after the EQ stage below.
    if mix_tone > 0.01:
        t = np.arange(frames) / SAMPLE_RATE
        tone_accum = np.zeros(frames)
        for i in range(NUM_BANDS):
            if band_active[i] and gains[i] > 0.01:
                freq = center_freqs[i]
                g = gains[i] * BASE_TONAL_VOL * mix_tone
                tone_accum += g * np.sin(2 * np.pi * freq * t + phases[i])
                # Carry the oscillator phase into the next block. Otherwise every
                # block restarts at the same phase and the tone clicks at each
                # block boundary.
                phases[i] = (phases[i] + 2 * np.pi * freq * frames / SAMPLE_RATE) % (2 * np.pi)
        final_mix += tone_accum

    # B. Broadband sources (white noise and/or the loaded file) feed the EQ.
    if mix_noise > 0.01:
        broadband_signal += np.random.uniform(-1, 1, size=frames) * BASE_NOISE_VOL * mix_noise
        has_broadband_content = True

    source = file_data  # read the global once; load_file_action may rebind it meanwhile
    if mix_file > 0.01 and source is not None and len(source) > 0:
        n_file = len(source)
        # Modular indexing loops the file seamlessly. It also handles files
        # shorter than one block, and a play head left past the end after a
        # shorter file was loaded. Plain slicing returns a short chunk in those
        # cases, and the length mismatch raises inside the callback.
        start = file_play_head % n_file
        chunk = source[(start + np.arange(frames)) % n_file]
        file_play_head = (start + frames) % n_file
        broadband_signal += chunk * BASE_FILE_VOL * mix_file
        has_broadband_content = True

    # C. Process Broadband
    if has_broadband_content:
        processed = np.zeros(frames)

        if passthrough_mode:
            # Bypass: audio passes unchanged. The meters still show per-band FFT
            # magnitudes, divided by a hand-tuned 12 to sit on the same display
            # scale as the Wavelet meters.
            processed = broadband_signal
            try:
                spectrum = np.fft.rfft(broadband_signal)
                bins_per_band = len(spectrum) / NUM_BANDS
                for i in range(NUM_BANDS):
                    s, e = int(i * bins_per_band), int((i + 1) * bins_per_band)
                    if e > s:
                        current_variances[i] = np.mean(np.abs(spectrum[s:e])) / 12.0
            except Exception:
                pass  # metering is best-effort and must never abort the audio stream

        else:
            if process_method == 'Wavelet':
                # Wavelet-packet EQ: decompose to WAVELET_LEVEL, scale each
                # frequency-ordered leaf node by its band gain, then reconstruct.
                try:
                    wp = pywt.WaveletPacket(data=broadband_signal, wavelet=CURRENT_WAVELET, mode='symmetric', maxlevel=WAVELET_LEVEL)
                    nodes = wp.get_level(WAVELET_LEVEL, order='freq')
                    count = min(len(nodes), NUM_BANDS)
                    for i in range(count):
                        g = gains[i] if band_active[i] else 0.0
                        processed_node_data = nodes[i].data * g
                        nodes[i].data = processed_node_data
                        if len(processed_node_data) > 0:
                            # Display-only meter scaling (hand-tuned constant).
                            current_variances[i] = np.mean(np.abs(processed_node_data)) * 5.0
                    processed = wp.reconstruct(update=False)
                except Exception:
                    # Any wavelet failure: pass the block through unprocessed
                    # rather than killing the stream.
                    processed = broadband_signal

            elif process_method == 'FFT':
                # Constant gain per band applied directly to the block's real-FFT
                # bins; each band owns an equal share of the bins.
                spectrum = np.fft.rfft(broadband_signal)
                bins_per_band = len(spectrum) / NUM_BANDS
                gain_mask = np.zeros(len(spectrum))
                for i in range(NUM_BANDS):
                    s, e = int(i * bins_per_band), int((i + 1) * bins_per_band)
                    g = gains[i] if band_active[i] else 0.0
                    gain_mask[s:e] = g
                    if e > s:
                        current_variances[i] = np.mean(np.abs(spectrum[s:e] * g)) / 12.0
                # n=frames: by default irfft returns an even length, one sample
                # short for an odd block size.
                processed = np.fft.irfft(spectrum * gain_mask, n=frames)

            elif process_method == 'MDCT':
                # DCT-IV of the whole block (no windowing or overlap-add), with the
                # coefficients split into equal-width bands.
                spectrum = dct(broadband_signal, type=4, norm='ortho')
                bins_per_band = len(spectrum) / NUM_BANDS
                gain_mask = np.zeros(len(spectrum))
                for i in range(NUM_BANDS):
                    s, e = int(i * bins_per_band), int((i + 1) * bins_per_band)
                    g = gains[i] if band_active[i] else 0.0
                    gain_mask[s:e] = g
                    if e > s:
                        current_variances[i] = np.mean(np.abs(spectrum[s:e] * g)) * 2.0
                processed = idct(spectrum * gain_mask, type=4, norm='ortho')

            if len(processed) != frames:
                # Trim or zero-pad to the block length; np.resize would instead
                # repeat samples from the start when the result is too short.
                fitted = np.zeros(frames)
                n_keep = min(frames, len(processed))
                fitted[:n_keep] = processed[:n_keep]
                processed = fitted

        final_mix += processed

    band_variances[:] = current_variances[:]
    outdata[:, 0] = final_mix * global_volume

# --- 2. Background UI Updater Thread ---
def ui_updater():
    """Mirror ``band_variances`` onto the band meters roughly every 100 ms.

    Each reading is clamped to the meters' 0..1 range. The loop exits once the
    global ``stop_threads`` is true.
    """
    while not stop_threads:
        time.sleep(0.1)
        snapshot = np.copy(band_variances)
        for meter, level in zip(meter_widgets_list, snapshot):
            meter.value = min(max(level, 0.0), 1.0)

# --- 3. UI Construction ---
slider_containers = []   # one VBox column per band
band_mute_widgets = []   # per-band enable toggles, in band order


def update_gain(change):
    """Copy a band slider's new value into ``gains`` (read by the audio callback)."""
    band = change['owner'].band_index
    gains[band] = change['new']


def toggle_band_mute(change):
    """Enable or disable a band and restyle its toggle button to match."""
    button = change['owner']
    enabled = change['new']
    band_active[button.band_index] = enabled
    if enabled:
        button.icon = 'volume-up'
        button.button_style = 'success'
    else:
        button.icon = 'volume-off'
        button.button_style = 'danger'


def set_single_max(b):
    """Move the clicked button's band slider to MAX_SLIDER_GAIN."""
    target = slider_widgets_list[b.band_index]
    target.value = MAX_SLIDER_GAIN


def set_single_min(b):
    """Move the clicked button's band slider to zero."""
    target = slider_widgets_list[b.band_index]
    target.value = 0.0

# Build one column per band, top to bottom: mute toggle, numeric gain box,
# "max" button, gain slider with its level meter side by side, "min" button,
# and a label showing the band's upper edge frequency (end_freqs).
for i in range(NUM_BANDS):
    start_val = float(gains[i])  # initial gain (threshold-of-hearing default)
    
    # Per-band enable toggle. `band_index` is an ad-hoc attribute that the shared
    # callbacks read back from change['owner'] or the clicked button.
    mute_btn = widgets.ToggleButton(value=True, button_style='success', icon='volume-up', layout=widgets.Layout(width='30px', height='30px', margin='0px'))
    mute_btn.band_index = i
    mute_btn.observe(toggle_band_mute, names='value')
    band_mute_widgets.append(mute_btn)

    # Numeric gain entry, kept in sync with the slider via widgets.jslink below.
    gain_box = widgets.BoundedFloatText(value=start_val, min=0.0, max=MAX_SLIDER_GAIN, step=0.01, layout=widgets.Layout(width='45px', height='30px'))
    
    # Vertical gain slider; update_gain copies its value into `gains`.
    slider = widgets.FloatSlider(value=start_val, min=0.0, max=MAX_SLIDER_GAIN, step=0.01, orientation='vertical', readout=False, layout=widgets.Layout(width='20px', height='250px'))
    slider.band_index = i
    widgets.jslink((slider, 'value'), (gain_box, 'value'))
    slider.observe(update_gain, names='value')
    slider_widgets_list.append(slider)
    
    # Level meter for this band, refreshed by the ui_updater thread.
    meter = widgets.FloatProgress(
        value=0.0, min=0.0, max=1.0, orientation='vertical', bar_style='info',
        layout=widgets.Layout(width='15px', height='250px', margin='0px 0px 0px 5px')
    )
    meter_widgets_list.append(meter)
    
    # Label with the band's upper edge: "689", "1.4k", ... (kHz from 1 kHz up).
    freq_val = end_freqs[i]
    freq_text = f"{freq_val/1000:.1f}k" if freq_val >= 1000 else f"{int(freq_val)}"
    freq_lbl = widgets.Label(value=freq_text, layout=widgets.Layout(width='40px', justify_content='center'))
    
    # Per-band shortcuts that jump the slider to its maximum / to zero.
    btn_max = widgets.Button(description='↑', layout=widgets.Layout(width='15px', height='20px', padding='0px'))
    btn_max.band_index = i; btn_max.on_click(set_single_max)
    btn_min = widgets.Button(description='↓', layout=widgets.Layout(width='15px', height='20px', padding='0px'))
    btn_min.band_index = i; btn_min.on_click(set_single_min)
    
    slider_meter_row = widgets.HBox([slider, meter], layout=widgets.Layout(align_items='center'))

    col = widgets.VBox(
        [mute_btn, gain_box, btn_max, slider_meter_row, btn_min, freq_lbl], 
        layout=widgets.Layout(align_items='center', margin='0px 5px 0px 0px', border='1px solid #eee')
    )
    slider_containers.append(col)

# Horizontally scrollable row holding all band columns.
eq_box = widgets.HBox(slider_containers, layout=widgets.Layout(overflow='x scroll', width='100%'))

# --- 4. Controls & Setup ---
# File source: path entry, load button and a status label updated by load_file_action.
file_path_text = widgets.Text(
    value=DEFAULT_FILE,
    placeholder='path/to/file.wav',
    description='File Path:',
)
load_file_btn = widgets.Button(description='Load', button_style='info', icon='upload')
file_status_lbl = widgets.Label(value="No file")

# Processing method and wavelet selection. The filter list starts as the default
# family's wavelets with DEFAULT_WAV_FILTER preselected.
method_selector = widgets.Dropdown(
    options=['Wavelet', 'FFT', 'MDCT'],
    value='Wavelet',
    description='Method:',
)
wav_family_selector = widgets.Dropdown(
    options=available_families,
    value=DEFAULT_WAV_FAMILY,
    description='Family:',
)
initial_options = pywt.wavelist(DEFAULT_WAV_FAMILY)
wav_name_selector = widgets.Dropdown(
    options=initial_options,
    value=DEFAULT_WAV_FILTER,
    description='Filter:',
)


def update_wav_names(change):
    """Refill the filter dropdown for the newly chosen family and select its first filter."""
    selector = wav_name_selector
    selector.options = pywt.wavelist(change['new'])
    selector.value = selector.options[0]


def update_current_wavelet(change):
    """Make the audio callback use the newly chosen wavelet filter."""
    global CURRENT_WAVELET
    CURRENT_WAVELET = change['new']


wav_family_selector.observe(update_wav_names, names='value')
wav_name_selector.observe(update_current_wavelet, names='value')

# Source mix levels in 0..1; observers wired further down copy them into the matching globals.
_mix_range = dict(min=0.0, max=1.0, step=0.01)
mix_noise_slider = widgets.FloatSlider(value=mix_noise, description='Noise:', **_mix_range)
mix_tone_slider = widgets.FloatSlider(value=mix_tone, description='Tone:', **_mix_range)
mix_file_slider = widgets.FloatSlider(value=mix_file, description='File:', **_mix_range)

passthrough_btn = widgets.ToggleButton(
    value=False,
    description='Passthrough',
    button_style='',
    icon='exchange',
    layout=widgets.Layout(width='120px'),
)


def toggle_passthrough(change):
    """Switch the broadband path between bypass and EQ processing, and relabel the button."""
    global passthrough_mode
    passthrough_mode = change['new']
    if passthrough_mode:
        passthrough_btn.button_style = 'warning'
        passthrough_btn.description = 'PT Active'
    else:
        passthrough_btn.button_style = ''
        passthrough_btn.description = 'Passthrough'


passthrough_btn.observe(toggle_passthrough, names='value')

# NOTE(review): `master_mute_btn` is created but never given a click handler and
# is not part of the displayed layout.
master_mute_btn = widgets.Button(description='Mute All', button_style='danger')
stop_btn = widgets.Button(description='Stop System', button_style='warning')
master_vol_slider = widgets.FloatSlider(
    value=1.0, min=0.0, max=1.0, step=0.01, description='Master:',
)

# --- NEW: RESET TO TOH BUTTON ---
reset_toh_btn = widgets.Button(description='Reset ToH', button_style='info', icon='refresh')


def reset_to_toh_action(b):
    """Restore every band slider to its threshold-of-hearing default."""
    for band, band_slider in enumerate(slider_widgets_list):
        default_gain = INITIAL_TOH_GAINS[band]
        band_slider.value = float(default_gain)


reset_toh_btn.on_click(reset_to_toh_action)

max_all_btn = widgets.Button(description='Max All', button_style='', icon='arrow-up', layout=widgets.Layout(width='80px'))
min_all_btn = widgets.Button(description='Min All', button_style='', icon='arrow-down', layout=widgets.Layout(width='80px'))


def set_max_all(b):
    """Push every band slider to MAX_SLIDER_GAIN."""
    for band_slider in slider_widgets_list:
        band_slider.value = MAX_SLIDER_GAIN


def set_min_all(b):
    """Pull every band slider down to zero."""
    for band_slider in slider_widgets_list:
        band_slider.value = 0.0


max_all_btn.on_click(set_max_all)
min_all_btn.on_click(set_min_all)

out_log = widgets.Output()

def load_file_action(b):
    """Load the file named in ``file_path_text`` as the mono 'File' source.

    - Multi-channel audio is averaged to mono.
    - The output stream runs at a fixed ``SAMPLE_RATE``, so a file recorded at
      another rate is resampled first; otherwise it would play back at the
      wrong speed and pitch.
    - On success the play head is reset to the start.
    - ``b`` is the clicked button (unused; ``None`` when called directly).
    - The outcome is reported in ``file_status_lbl`` rather than raised.
    """
    global file_data, file_play_head
    try:
        data, fs = sf.read(file_path_text.value)
    except Exception as err:  # missing file, unsupported format, decoder error
        file_status_lbl.value = f"Error (Check Path): {str(err)[:40]}"
        return
    if data.size == 0:
        # An empty buffer would leave the audio callback nothing to loop over.
        file_status_lbl.value = "Error (Empty file)"
        return
    if data.ndim > 1:
        data = np.mean(data, axis=1)
    status = "Loaded"
    if fs != SAMPLE_RATE:
        from math import gcd
        from scipy.signal import resample_poly
        common = gcd(int(fs), int(SAMPLE_RATE))
        data = resample_poly(data, int(SAMPLE_RATE) // common, int(fs) // common)
        status = f"Loaded (resampled {fs} Hz -> {SAMPLE_RATE} Hz)"
    file_play_head = 0
    file_data = data.astype(np.float32)
    file_status_lbl.value = status

def stop_stream(b):
    """Stop audio output and tell the meter-updater thread to exit.

    Safe to call repeatedly: the stream handle is cleared once closed, so a
    second click does not call ``stop()``/``close()`` on a closed stream.
    ``b`` is the clicked button (unused).
    """
    global global_stream, stop_threads
    stop_threads = True
    if global_stream is not None:
        try:
            global_stream.stop()
            global_stream.close()
        finally:
            global_stream = None
    print("Stopped.")

load_file_btn.on_click(load_file_action)
stop_btn.on_click(stop_stream)
# The audio callback reads these plain module globals directly, so each control
# just rebinds the matching global.
mix_noise_slider.observe(lambda c: globals().update(mix_noise=c['new']), names='value')
mix_tone_slider.observe(lambda c: globals().update(mix_tone=c['new']), names='value')
mix_file_slider.observe(lambda c: globals().update(mix_file=c['new']), names='value')
method_selector.observe(lambda c: globals().update(process_method=c['new']), names='value')
master_vol_slider.observe(lambda c: globals().update(global_volume=c['new']), names='value')

load_file_action(None)

print(f"Starting System ({NUM_BANDS} Bands) | File: {DEFAULT_FILE} | Filter: {DEFAULT_WAV_FILTER}")
try:
    # Close a stream that is still open from an earlier start. A stream already
    # closed by "Stop System" is skipped, since stopping it again may raise.
    if global_stream is not None and not getattr(global_stream, 'closed', False):
        global_stream.stop()
        global_stream.close()
    sd.stop()
    time.sleep(0.2)

    global_stream = sd.OutputStream(samplerate=SAMPLE_RATE, blocksize=BLOCK_SIZE, channels=1, callback=audio_callback)
    global_stream.start()

    stop_threads = False
    # daemon=True: the meter loop runs until stop_threads is set and must not keep
    # the interpreter alive at exit if the system was never stopped.
    t = threading.Thread(target=ui_updater, name='eq-ui-updater', daemon=True)
    t.start()

    top_row = widgets.HBox([file_path_text, load_file_btn, file_status_lbl])
    algo_row = widgets.HBox([method_selector, wav_family_selector, wav_name_selector])
    mix_row = widgets.HBox([mix_noise_slider, mix_tone_slider, mix_file_slider])

    # Updated Control Row
    ctrl_row = widgets.HBox([passthrough_btn, reset_toh_btn, max_all_btn, min_all_btn, stop_btn, master_vol_slider])

    display(widgets.VBox([top_row, algo_row, mix_row, ctrl_row, eq_box, out_log]))

except Exception as e:
    print(f"Error: {e}")

Starting System (32 Bands) | File: ../data/Metallica_Orion.mp3 | Filter: db5


VBox(children=(HBox(children=(Text(value='../data/Metallica_Orion.mp3', description='File Path:', placeholder=…

In [1]:
import numpy as np
import sounddevice as sd
import ipywidgets as widgets
import pywt
import soundfile as sf
import time
import threading
import json
import os
from scipy.fft import dct, idct
from IPython.display import display, clear_output

# --- Configuration ---
SAMPLE_RATE = 44100                      # output stream sample rate (Hz)
BLOCK_SIZE = 4096                        # samples requested per audio callback
NUM_BANDS = 32                           # equal-width EQ bands spanning 0 Hz .. Nyquist
WAVELET_LEVEL = int(np.log2(NUM_BANDS))  # packet depth; 2**level leaves == bands when NUM_BANDS is a power of two

# --- DEFAULTS ---
DEFAULT_FILE = "../data/Metallica_Orion.mp3"
DEFAULT_WAV_FAMILY = 'db'
DEFAULT_WAV_FILTER = 'db5'

MAX_SLIDER_GAIN = 1.0                    # upper limit of every band gain slider
CURRENT_WAVELET = DEFAULT_WAV_FILTER     # filter used by the audio callback; changed from the UI

# Base Volumes: fixed per-source scale factors applied before each source's mix slider
BASE_NOISE_VOL = 0.2
BASE_TONAL_VOL = 0.05
BASE_FILE_VOL = 0.8

# --- FREQUENCY CALCULATIONS ---
nyquist = SAMPLE_RATE / 2
bandwidth = nyquist / NUM_BANDS                                   # width of one band (Hz)
end_freqs = np.arange(1, NUM_BANDS + 1) * bandwidth               # upper edge of each band
center_freqs = np.arange(NUM_BANDS) * bandwidth + bandwidth / 2   # midpoint of each band

# --- THRESHOLD OF HEARING MODEL ---
def ToH_model(f):
    """Return the approximate absolute threshold of hearing (dB) at ``f`` Hz.

    The curve is built from three terms:
    - a low-frequency rise, 3.64 * (f/1000)**-0.8;
    - a dip of 6.5 centred on 3.3 kHz;
    - a high-frequency rise, 1e-3 * (f/1000)**4.

    Works on scalars and NumPy arrays. Inputs are clamped to at least 1e-6 Hz
    so 0 Hz stays finite.
    """
    khz = np.maximum(f, 1e-6) / 1000
    low_rise = 3.64 * khz ** (-0.8)
    mid_dip = -6.5 * np.exp(-0.6 * ((khz - 3.3) ** 2))
    high_rise = 10 ** (-3) * khz ** 4
    return low_rise + mid_dip + high_rise

# Default band gains follow the inverted, min-max normalised threshold-of-hearing
# curve: lowest modelled threshold -> MAX_SLIDER_GAIN, highest -> 0.
toh_raw = ToH_model(center_freqs)
t_min, t_max = np.min(toh_raw), np.max(toh_raw)
toh_normalized = (toh_raw - t_min) / (t_max - t_min) if t_max > t_min else np.zeros_like(toh_raw)
INITIAL_TOH_GAINS = ((1.0 - toh_normalized) * MAX_SLIDER_GAIN).astype(np.float32)

# --- Clean up a previous run (this cell or the earlier one) ---
# `global_stream` and `stop_threads` are rebound below. An OutputStream started
# earlier would otherwise keep playing (sd.stop() does not stop it) and its meter
# thread keep running, with no handle left to stop either. Shut both down first.
_previous_stream = globals().get('global_stream')
if _previous_stream is not None and not getattr(_previous_stream, 'closed', False):
    try:
        _previous_stream.stop()
        _previous_stream.close()
    except Exception as err:
        print(f"Could not close previous stream: {err}")
stop_threads = True  # makes a previous ui_updater loop exit after its current sleep
_previous_ui_thread = globals().get('t')
if isinstance(_previous_ui_thread, threading.Thread) and _previous_ui_thread.is_alive():
    _previous_ui_thread.join(timeout=1.0)

# --- Global State ---
gains = np.copy(INITIAL_TOH_GAINS)                        # live per-band gains (UI writes, callback reads)
band_active = np.ones(NUM_BANDS, dtype=bool)              # per-band enable flags
band_variances = np.zeros(NUM_BANDS, dtype=np.float32)    # per-band meter levels from the callback

global_volume = 1.0
slider_widgets_list = []
meter_widgets_list = []
global_stream = None
stop_threads = False

# Source mix levels (0..1) and processing mode, read by the audio callback
mix_noise = 0.0
mix_tone = 0.0
mix_file = 1.0
passthrough_mode = False
process_method = 'Wavelet'

file_data = None              # mono samples of the loaded file
file_play_head = 0            # next sample index to read from file_data
phases = np.zeros(NUM_BANDS)  # per-band tone oscillator phase (radians)
# Families offered in the UI: those not listed here are excluded from the dropdown.
available_families = [f for f in pywt.families(short=True) if f not in ['bior', 'rbio', 'dmey', 'gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor']]

# Master Metering Globals: scaled RMS and absolute peak of the last output block
master_variance = 0.0
master_peak = 0.0


# --- 1. The Audio Callback ---
def audio_callback(outdata, frames, time_info, status):
    """Build, equalize and emit one block as a sounddevice ``OutputStream`` callback.

    Args:
        outdata: output buffer, filled in place. Only column 0 is written,
            because the stream is opened with ``channels=1``.
        frames: number of samples requested for this block.
        time_info: stream timing information (unused).
        status: over/underflow flags; printed when set.

    Side effects on module globals:
        phases: per-band tone oscillator phase, carried between blocks.
        file_play_head: read position in ``file_data``.
        band_variances: per-band meter levels for the UI thread.
        master_variance, master_peak: 4x RMS and absolute peak of the output.
    """
    global phases, file_play_head, band_variances
    global master_variance, master_peak

    if status:
        print(status)

    current_variances = np.zeros(NUM_BANDS)

    # Master volume effectively zero: emit silence and reset all meters.
    if global_volume <= 0.001:
        outdata.fill(0)
        band_variances[:] = 0
        master_variance = 0.0
        master_peak = 0.0
        return

    final_mix = np.zeros(frames)
    broadband_signal = np.zeros(frames)
    has_broadband_content = False

    # A. Tone generator: one sine per enabled band at the band's centre frequency,
    # already scaled by the band gain. It is added after the EQ stage below.
    if mix_tone > 0.01:
        t = np.arange(frames) / SAMPLE_RATE
        tone_accum = np.zeros(frames)
        for i in range(NUM_BANDS):
            if band_active[i] and gains[i] > 0.01:
                freq = center_freqs[i]
                g = gains[i] * BASE_TONAL_VOL * mix_tone
                tone_accum += g * np.sin(2 * np.pi * freq * t + phases[i])
                # Carry the phase into the next block so the tone stays continuous.
                phases[i] = (phases[i] + 2 * np.pi * freq * frames / SAMPLE_RATE) % (2 * np.pi)
        final_mix += tone_accum

    # B. Broadband sources (white noise and/or the loaded file) feed the EQ.
    if mix_noise > 0.01:
        broadband_signal += np.random.uniform(-1, 1, size=frames) * BASE_NOISE_VOL * mix_noise
        has_broadband_content = True

    source = file_data  # read the global once; load_file_action may rebind it meanwhile
    if mix_file > 0.01 and source is not None and len(source) > 0:
        n_file = len(source)
        # Modular indexing loops the file seamlessly. It also handles files
        # shorter than one block, and a play head left past the end after a
        # shorter file was loaded. Plain slicing returns a short chunk in those
        # cases, and the length mismatch raises inside the callback.
        start = file_play_head % n_file
        chunk = source[(start + np.arange(frames)) % n_file]
        file_play_head = (start + frames) % n_file
        broadband_signal += chunk * BASE_FILE_VOL * mix_file
        has_broadband_content = True

    # C. Process Broadband
    if has_broadband_content:
        processed = np.zeros(frames)

        # Extra display gain applied only to the per-band UI meters.
        SUBBAND_METER_SCALE = 40.0

        if passthrough_mode:
            # Bypass: audio passes unchanged; meters still show per-band FFT RMS.
            processed = broadband_signal
            try:
                spectrum = np.fft.rfft(broadband_signal)
                bins_per_band = len(spectrum) / NUM_BANDS
                for i in range(NUM_BANDS):
                    s, e = int(i * bins_per_band), int((i + 1) * bins_per_band)
                    if e > s:
                        raw_rms = np.sqrt(np.mean(np.square(np.abs(spectrum[s:e]))))
                        normalized_rms = raw_rms / (frames / 2.0)
                        current_variances[i] = normalized_rms * SUBBAND_METER_SCALE
            except Exception:
                pass  # metering is best-effort and must never abort the audio stream

        else:
            if process_method == 'Wavelet':
                # Wavelet-packet EQ: decompose to WAVELET_LEVEL, scale each
                # frequency-ordered leaf node by its band gain, then reconstruct.
                try:
                    wp = pywt.WaveletPacket(data=broadband_signal, wavelet=CURRENT_WAVELET, mode='symmetric', maxlevel=WAVELET_LEVEL)
                    nodes = wp.get_level(WAVELET_LEVEL, order='freq')
                    count = min(len(nodes), NUM_BANDS)
                    for i in range(count):
                        g = gains[i] if band_active[i] else 0.0
                        processed_node_data = nodes[i].data * g
                        nodes[i].data = processed_node_data
                        if len(processed_node_data) > 0:
                            raw_rms = np.sqrt(np.mean(np.square(np.abs(processed_node_data))))
                            normalized_rms = raw_rms / np.sqrt(frames / 2.0)
                            current_variances[i] = normalized_rms * SUBBAND_METER_SCALE
                    processed = wp.reconstruct(update=False)
                except Exception:
                    # Any wavelet failure: pass the block through unprocessed
                    # rather than killing the stream.
                    processed = broadband_signal

            elif process_method == 'FFT':
                # Constant gain per band applied directly to the block's real-FFT
                # bins; each band owns an equal share of the bins.
                spectrum = np.fft.rfft(broadband_signal)
                bins_per_band = len(spectrum) / NUM_BANDS
                gain_mask = np.zeros(len(spectrum))
                for i in range(NUM_BANDS):
                    s, e = int(i * bins_per_band), int((i + 1) * bins_per_band)
                    g = gains[i] if band_active[i] else 0.0
                    gain_mask[s:e] = g
                    if e > s:
                        raw_rms = np.sqrt(np.mean(np.square(np.abs(spectrum[s:e] * g))))
                        normalized_rms = raw_rms / (frames / 2.0)
                        current_variances[i] = normalized_rms * SUBBAND_METER_SCALE
                # n=frames: by default irfft returns an even length, one sample
                # short for an odd block size.
                processed = np.fft.irfft(spectrum * gain_mask, n=frames)

            elif process_method == 'MDCT':
                # DCT-IV of the whole block (no windowing or overlap-add), with the
                # coefficients split into equal-width bands.
                spectrum = dct(broadband_signal, type=4, norm='ortho')
                bins_per_band = len(spectrum) / NUM_BANDS
                gain_mask = np.zeros(len(spectrum))
                for i in range(NUM_BANDS):
                    s, e = int(i * bins_per_band), int((i + 1) * bins_per_band)
                    g = gains[i] if band_active[i] else 0.0
                    gain_mask[s:e] = g
                    if e > s:
                        raw_rms = np.sqrt(np.mean(np.square(np.abs(spectrum[s:e] * g))))
                        normalized_rms = raw_rms / np.sqrt(frames / 2.0)
                        current_variances[i] = normalized_rms * SUBBAND_METER_SCALE
                processed = idct(spectrum * gain_mask, type=4, norm='ortho')

            if len(processed) != frames:
                # Trim or zero-pad to the block length; np.resize would instead
                # repeat samples from the start when the result is too short.
                fitted = np.zeros(frames)
                n_keep = min(frames, len(processed))
                fitted[:n_keep] = processed[:n_keep]
                processed = fitted

        final_mix += processed

    # Final Output & Updates
    band_variances[:] = current_variances[:]
    out_signal = final_mix * global_volume
    outdata[:, 0] = out_signal

    # Master metering: RMS scaled by 4.0 for the meter, plus the absolute peak
    # that drives the clipping indicator.
    master_rms = np.sqrt(np.mean(np.square(out_signal)))
    master_variance = master_rms * 4.0
    master_peak = np.max(np.abs(out_signal))


# --- 2. Background UI Updater Thread ---
def ui_updater():
    """Refresh the band meters and the master meter roughly every 100 ms.

    Meters use peak-hold with multiplicative fall-off: a louder reading is
    shown immediately, otherwise the displayed value is multiplied by
    DECAY_FACTOR each tick. The master meter's bar_style is set as follows:
    - 'danger' when the last block's peak reached 0.99;
    - 'warning' when the held level exceeds 0.8;
    - 'success' otherwise.
    The loop exits once the global ``stop_threads`` is true.
    """
    DECAY_FACTOR = 0.85

    def held(previous, target):
        # Jump up to a louder reading, otherwise let the display fall off.
        return target if target > previous else previous * DECAY_FACTOR

    band_levels = np.zeros(NUM_BANDS, dtype=np.float32)
    master_level = 0.0

    while not stop_threads:
        time.sleep(0.1)

        # Band meters
        snapshot = np.copy(band_variances)
        for idx, (meter, raw) in enumerate(zip(meter_widgets_list, snapshot)):
            band_levels[idx] = held(band_levels[idx], min(max(raw, 0.0), 1.0))
            meter.value = band_levels[idx]

        # Master meter and clipping indicator
        master_level = held(master_level, min(max(master_variance, 0.0), 1.0))
        master_meter.value = master_level

        if master_peak >= 0.99:
            style = 'danger'
        elif master_level > 0.8:
            style = 'warning'
        else:
            style = 'success'
        master_meter.bar_style = style


# --- 3. UI Construction ---
slider_containers = []   # one VBox column per band
band_mute_widgets = []   # per-band enable toggles, in band order


def update_gain(change):
    """Copy a band slider's new value into ``gains`` (read by the audio callback)."""
    band = change['owner'].band_index
    gains[band] = change['new']


def toggle_band_mute(change):
    """Enable or disable a band and restyle its toggle button to match."""
    button = change['owner']
    enabled = change['new']
    band_active[button.band_index] = enabled
    if enabled:
        button.icon = 'volume-up'
        button.button_style = 'success'
    else:
        button.icon = 'volume-off'
        button.button_style = 'danger'


def set_single_max(b):
    """Move the clicked button's band slider to MAX_SLIDER_GAIN."""
    target = slider_widgets_list[b.band_index]
    target.value = MAX_SLIDER_GAIN


def set_single_min(b):
    """Move the clicked button's band slider to zero."""
    target = slider_widgets_list[b.band_index]
    target.value = 0.0

# Build one column per band, top to bottom: mute toggle, numeric gain box,
# "max" button, gain slider with its level meter side by side, "min" button,
# and a label showing the band's upper edge frequency (end_freqs).
for i in range(NUM_BANDS):
    start_val = float(gains[i])  # initial gain (threshold-of-hearing default)
    
    # Per-band enable toggle. `band_index` is an ad-hoc attribute that the shared
    # callbacks read back from change['owner'] or the clicked button.
    mute_btn = widgets.ToggleButton(value=True, button_style='success', icon='volume-up', layout=widgets.Layout(width='30px', height='30px', margin='0px'))
    mute_btn.band_index = i
    mute_btn.observe(toggle_band_mute, names='value')
    band_mute_widgets.append(mute_btn)

    # Numeric gain entry, kept in sync with the slider via widgets.jslink below.
    gain_box = widgets.BoundedFloatText(value=start_val, min=0.0, max=MAX_SLIDER_GAIN, step=0.01, layout=widgets.Layout(width='45px', height='30px'))
    
    # Vertical gain slider; update_gain copies its value into `gains`.
    slider = widgets.FloatSlider(value=start_val, min=0.0, max=MAX_SLIDER_GAIN, step=0.01, orientation='vertical', readout=False, layout=widgets.Layout(width='20px', height='250px'))
    slider.band_index = i
    widgets.jslink((slider, 'value'), (gain_box, 'value'))
    slider.observe(update_gain, names='value')
    slider_widgets_list.append(slider)
    
    # Level meter for this band, refreshed by the ui_updater thread.
    meter = widgets.FloatProgress(value=0.0, min=0.0, max=1.0, orientation='vertical', bar_style='info', layout=widgets.Layout(width='15px', height='250px', margin='0px 0px 0px 5px'))
    meter_widgets_list.append(meter)
    
    # Label with the band's upper edge: "689", "1.4k", ... (kHz from 1 kHz up).
    freq_val = end_freqs[i]
    freq_text = f"{freq_val/1000:.1f}k" if freq_val >= 1000 else f"{int(freq_val)}"
    freq_lbl = widgets.Label(value=freq_text, layout=widgets.Layout(width='40px', justify_content='center'))
    
    # Per-band shortcuts that jump the slider to its maximum / to zero.
    btn_max = widgets.Button(description='↑', layout=widgets.Layout(width='15px', height='20px', padding='0px'))
    btn_max.band_index = i; btn_max.on_click(set_single_max)
    btn_min = widgets.Button(description='↓', layout=widgets.Layout(width='15px', height='20px', padding='0px'))
    btn_min.band_index = i; btn_min.on_click(set_single_min)
    
    slider_meter_row = widgets.HBox([slider, meter], layout=widgets.Layout(align_items='center'))

    col = widgets.VBox([mute_btn, gain_box, btn_max, slider_meter_row, btn_min, freq_lbl], layout=widgets.Layout(align_items='center', margin='0px 5px 0px 0px', border='1px solid #eee'))
    slider_containers.append(col)

# Horizontally scrollable row holding all band columns.
eq_box = widgets.HBox(slider_containers, layout=widgets.Layout(overflow='x scroll', width='100%'))


# --- 4. Controls & Setup ---
# File source: path entry, load button, and a status label written by load_file_action.
file_path_text = widgets.Text(value=DEFAULT_FILE, placeholder='path/to/file.wav', description='File Path:')
load_file_btn = widgets.Button(description='Load', button_style='info', icon='upload')
file_status_lbl = widgets.Label(value="No file")

# Preset UI: the text box holds the JSON file path used by both save_preset_action
# and load_preset_action; the label shows their result.
preset_name_input = widgets.Text(value='my_eq_preset.json', description='Preset:', layout=widgets.Layout(width='200px'))
save_preset_btn = widgets.Button(description='Save', button_style='success', icon='save', layout=widgets.Layout(width='80px'))
load_preset_btn = widgets.Button(description='Load', button_style='info', icon='folder-open', layout=widgets.Layout(width='80px'))
preset_status_lbl = widgets.Label(value="")

def save_preset_action(b):
    """Write the current EQ and mixer settings as JSON to the path in ``preset_name_input``.

    Saved keys are: band gains, band enable flags, processing method, wavelet
    family/name, the three source mix levels and master volume. ``b`` is the
    clicked button (unused). The outcome is shown in ``preset_status_lbl``;
    write errors are reported there (truncated), not raised.
    """
    preset_data = dict(
        gains=[float(band_slider.value) for band_slider in slider_widgets_list],
        band_active=[bool(toggle.value) for toggle in band_mute_widgets],
        method=method_selector.value,
        wavelet_family=wav_family_selector.value,
        wavelet_name=wav_name_selector.value,
        mix_noise=float(mix_noise_slider.value),
        mix_tone=float(mix_tone_slider.value),
        mix_file=float(mix_file_slider.value),
        master_vol=float(master_vol_slider.value),
    )
    try:
        target = preset_name_input.value
        with open(target, 'w') as handle:
            json.dump(preset_data, handle, indent=4)
        preset_status_lbl.value = "Preset saved!"
    except Exception as e:
        preset_status_lbl.value = f"Save Error: {str(e)[:20]}"

def load_preset_action(b):
    """Apply an EQ preset read from the JSON file named in ``preset_name_input``.

    Recognised keys are the ones save_preset_action writes, applied in this order:
    gains, band_active, method, wavelet_family, wavelet_name, mix_noise,
    mix_tone, mix_file, master_vol.

    Each key is applied independently. A single bad entry, for example a method
    or wavelet family the dropdowns do not offer, no longer aborts the rest of
    the preset. Keys that could not be applied are listed in
    ``preset_status_lbl``. ``b`` is the clicked button (unused).
    """
    filepath = preset_name_input.value
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        preset_status_lbl.value = "File not found!"
        return
    except (OSError, ValueError) as e:  # unreadable path, invalid JSON or encoding
        preset_status_lbl.value = f"Load Error: {str(e)[:20]}"
        return
    if not isinstance(data, dict):
        preset_status_lbl.value = "Load Error: not a preset"
        return

    def set_values(widget_list, values):
        # Entries beyond the number of bands are ignored.
        for widget, value in zip(widget_list, values):
            widget.value = value

    def set_wavelet_name(name):
        # Runs after the family is applied, which repopulates the filter options.
        if name not in wav_name_selector.options:
            raise ValueError(name)
        wav_name_selector.value = name

    appliers = (
        ('gains', lambda v: set_values(slider_widgets_list, v)),
        ('band_active', lambda v: set_values(band_mute_widgets, v)),
        ('method', lambda v: setattr(method_selector, 'value', v)),
        ('wavelet_family', lambda v: setattr(wav_family_selector, 'value', v)),
        ('wavelet_name', set_wavelet_name),
        ('mix_noise', lambda v: setattr(mix_noise_slider, 'value', v)),
        ('mix_tone', lambda v: setattr(mix_tone_slider, 'value', v)),
        ('mix_file', lambda v: setattr(mix_file_slider, 'value', v)),
        ('master_vol', lambda v: setattr(master_vol_slider, 'value', v)),
    )

    skipped = []
    for key, apply_value in appliers:
        if key not in data:
            continue
        try:
            apply_value(data[key])
        except Exception:  # e.g. a widget rejecting the value (traitlets.TraitError), wrong type
            skipped.append(key)

    if skipped:
        preset_status_lbl.value = f"Loaded, skipped: {', '.join(skipped)}"
    else:
        preset_status_lbl.value = "Preset loaded!"

save_preset_btn.on_click(save_preset_action)
load_preset_btn.on_click(load_preset_action)

# Processing method and wavelet selection. The filter list starts as the default
# family's wavelets with DEFAULT_WAV_FILTER preselected.
method_selector = widgets.Dropdown(
    options=['Wavelet', 'FFT', 'MDCT'],
    value='Wavelet',
    description='Method:',
)
wav_family_selector = widgets.Dropdown(
    options=available_families,
    value=DEFAULT_WAV_FAMILY,
    description='Family:',
)
initial_options = pywt.wavelist(DEFAULT_WAV_FAMILY)
wav_name_selector = widgets.Dropdown(
    options=initial_options,
    value=DEFAULT_WAV_FILTER,
    description='Filter:',
)


def update_wav_names(change):
    """Refill the filter dropdown for the newly chosen family and select its first filter."""
    selector = wav_name_selector
    selector.options = pywt.wavelist(change['new'])
    selector.value = selector.options[0]


def update_current_wavelet(change):
    """Make the audio callback use the newly chosen wavelet filter."""
    global CURRENT_WAVELET
    CURRENT_WAVELET = change['new']


wav_family_selector.observe(update_wav_names, names='value')
wav_name_selector.observe(update_current_wavelet, names='value')

# Source mix levels in 0..1; observers wired further down copy them into the matching globals.
_mix_range = dict(min=0.0, max=1.0, step=0.01)
mix_noise_slider = widgets.FloatSlider(value=mix_noise, description='Noise:', **_mix_range)
mix_tone_slider = widgets.FloatSlider(value=mix_tone, description='Tone:', **_mix_range)
mix_file_slider = widgets.FloatSlider(value=mix_file, description='File:', **_mix_range)

passthrough_btn = widgets.ToggleButton(
    value=False,
    description='Passthrough',
    button_style='',
    icon='exchange',
    layout=widgets.Layout(width='120px'),
)


def toggle_passthrough(change):
    """Switch the broadband path between bypass and EQ processing, and relabel the button."""
    global passthrough_mode
    passthrough_mode = change['new']
    if passthrough_mode:
        passthrough_btn.button_style = 'warning'
        passthrough_btn.description = 'PT Active'
    else:
        passthrough_btn.button_style = ''
        passthrough_btn.description = 'Passthrough'


passthrough_btn.observe(toggle_passthrough, names='value')

stop_btn = widgets.Button(description='Stop System', button_style='warning')
master_vol_slider = widgets.FloatSlider(
    value=1.0, min=0.0, max=1.0, step=0.01, description='Master:',
)

# Master Meter: horizontal bar refreshed (and recoloured) by ui_updater.
master_meter = widgets.FloatProgress(
    value=0.0,
    min=0.0,
    max=1.0,
    orientation='horizontal',
    bar_style='success',
    layout=widgets.Layout(width='150px', height='20px', margin='0px 0px 0px 10px'),
)

# Reset all band sliders to the threshold-of-hearing defaults.
reset_toh_btn = widgets.Button(description='Reset ToH', button_style='info', icon='refresh')


def reset_to_toh_action(b):
    """Restore every band slider to its INITIAL_TOH_GAINS value."""
    for band, band_slider in enumerate(slider_widgets_list):
        default_gain = INITIAL_TOH_GAINS[band]
        band_slider.value = float(default_gain)


reset_toh_btn.on_click(reset_to_toh_action)

max_all_btn = widgets.Button(description='Max All', button_style='', icon='arrow-up', layout=widgets.Layout(width='80px'))
min_all_btn = widgets.Button(description='Min All', button_style='', icon='arrow-down', layout=widgets.Layout(width='80px'))


def set_max_all(b):
    """Push every band slider to MAX_SLIDER_GAIN."""
    for band_slider in slider_widgets_list:
        band_slider.value = MAX_SLIDER_GAIN


def set_min_all(b):
    """Pull every band slider down to zero."""
    for band_slider in slider_widgets_list:
        band_slider.value = 0.0


max_all_btn.on_click(set_max_all)
min_all_btn.on_click(set_min_all)

out_log = widgets.Output()

def load_file_action(b):
    """Load the file named in ``file_path_text`` as the mono 'File' source.

    - Multi-channel audio is averaged to mono.
    - The output stream runs at a fixed ``SAMPLE_RATE``, so a file recorded at
      another rate is resampled first; otherwise it would play back at the
      wrong speed and pitch.
    - On success the play head is reset to the start.
    - ``b`` is the clicked button (unused; ``None`` when called directly).
    - The outcome is reported in ``file_status_lbl`` rather than raised.
    """
    global file_data, file_play_head
    try:
        data, fs = sf.read(file_path_text.value)
    except Exception as err:  # missing file, unsupported format, decoder error
        file_status_lbl.value = f"Error (Check Path): {str(err)[:40]}"
        return
    if data.size == 0:
        # An empty buffer would leave the audio callback nothing to loop over.
        file_status_lbl.value = "Error (Empty file)"
        return
    if data.ndim > 1:
        data = np.mean(data, axis=1)
    status = "Loaded"
    if fs != SAMPLE_RATE:
        from math import gcd
        from scipy.signal import resample_poly
        common = gcd(int(fs), int(SAMPLE_RATE))
        data = resample_poly(data, int(SAMPLE_RATE) // common, int(fs) // common)
        status = f"Loaded (resampled {fs} Hz -> {SAMPLE_RATE} Hz)"
    file_play_head = 0
    file_data = data.astype(np.float32)
    file_status_lbl.value = status

def stop_stream(b):
    """Stop audio output and tell the meter-updater thread to exit.

    Safe to call repeatedly: the stream handle is cleared once closed, so a
    second click does not call ``stop()``/``close()`` on a closed stream.
    ``b`` is the clicked button (unused).
    """
    global global_stream, stop_threads
    stop_threads = True
    if global_stream is not None:
        try:
            global_stream.stop()
            global_stream.close()
        finally:
            global_stream = None
    print("Stopped.")

load_file_btn.on_click(load_file_action)
stop_btn.on_click(stop_stream)
# The audio callback reads these plain module globals directly, so each control
# just rebinds the matching global.
mix_noise_slider.observe(lambda c: globals().update(mix_noise=c['new']), names='value')
mix_tone_slider.observe(lambda c: globals().update(mix_tone=c['new']), names='value')
mix_file_slider.observe(lambda c: globals().update(mix_file=c['new']), names='value')
method_selector.observe(lambda c: globals().update(process_method=c['new']), names='value')
master_vol_slider.observe(lambda c: globals().update(global_volume=c['new']), names='value')

load_file_action(None)
print(f"Starting System ({NUM_BANDS} Bands) | File: {DEFAULT_FILE} | Filter: {DEFAULT_WAV_FILTER}")

# --- 5. Start Engine & Display UI ---
try:
    # Close a stream that is still open from an earlier start. A stream already
    # closed by "Stop System" is skipped, since stopping it again may raise.
    if global_stream is not None and not getattr(global_stream, 'closed', False):
        global_stream.stop()
        global_stream.close()
    sd.stop()
    time.sleep(0.2)

    global_stream = sd.OutputStream(samplerate=SAMPLE_RATE, blocksize=BLOCK_SIZE, channels=1, callback=audio_callback)
    global_stream.start()

    stop_threads = False
    # daemon=True: the meter loop runs until stop_threads is set and must not keep
    # the interpreter alive at exit if the system was never stopped.
    t = threading.Thread(target=ui_updater, name='eq-ui-updater', daemon=True)
    t.start()

    top_row = widgets.HBox([file_path_text, load_file_btn, file_status_lbl])
    preset_row = widgets.HBox([preset_name_input, save_preset_btn, load_preset_btn, preset_status_lbl])
    algo_row = widgets.HBox([method_selector, wav_family_selector, wav_name_selector])
    mix_row = widgets.HBox([mix_noise_slider, mix_tone_slider, mix_file_slider])
    ctrl_row = widgets.HBox([passthrough_btn, reset_toh_btn, max_all_btn, min_all_btn, stop_btn, widgets.HBox([master_vol_slider, master_meter], layout=widgets.Layout(align_items='center'))])

    display(widgets.VBox([top_row, preset_row, algo_row, mix_row, ctrl_row, eq_box, out_log]))

except Exception as e:
    print(f"Error: {e}")

Starting System (32 Bands) | File: ../data/Metallica_Orion.mp3 | Filter: db5


VBox(children=(HBox(children=(Text(value='../data/Metallica_Orion.mp3', description='File Path:', placeholder=…