In [3]:
import pandas as pd
import torchaudio
from pathlib import Path
import sys
import numpy as np
import random
import torch
import IPython.display as ipd
import matplotlib.pyplot as plt
from ipywidgets import widgets, Layout, VBox, HBox, Button, HTML as HTMLWidget
from IPython.display import display, clear_output, HTML
import io

# Set random seeds for reproducibility (torch, stdlib random and NumPy share SEED)
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Setup paths
# NOTE(review): repo_root is the parent of the current working directory, so this
# presumably expects the notebook to be started from a direct sub-folder of the
# repository -- confirm before running from elsewhere.
repo_root = Path.cwd().parent
# Make the packages under <repo_root>/src importable; the project imports below
# rely on this, so they must stay after this line.
sys.path.insert(0, str(repo_root / "src"))

# Import your modules
from utils.audio_dataset_loader import preprocess_audio
from deep_learning.gtcrn import GTCRN
from dsp_algorithms.mband_var import mband
from dsp_algorithms.wiener_as import wiener_filter

# Initialize models
print("Loading models...")
device = torch.device("cpu")
# eval() switches the network to inference mode before the weights are loaded.
gtcrn_model = GTCRN().eval()
ckpt_path = repo_root / "src" / "deep_learning" / "gtcrn" / "gtcrn_main" / "checkpoints" / "model_trained_on_dns3.tar"
# The checkpoint is indexed with 'model' below, i.e. the state dict is expected
# under that key; map_location keeps every tensor on `device`.
ckpt = torch.load(ckpt_path, map_location=device)
gtcrn_model.load_state_dict(ckpt['model'])
# The check mark was mis-encoded (UTF-8 bytes decoded as Mac Roman); restored here.
print("✓ Models loaded successfully!")

# Define paths
noise_path = repo_root / "sound_data" / "raw" / "NOIZEUS_NOISE_DATASET"
clean_sound_path = repo_root / "sound_data" / "raw" / "EARS_DATASET" / "p092"

# Display label -> file name of each noise recording.
# Every recording is a .wav file, so only the stem is listed and the extension
# is appended once; insertion order is the order shown in the dropdown.
NOISE_FILES = {
    label: stem + ".wav"
    for label, stem in (
        ("Cafeteria Babble", "cafeteria_babble"),
        ("Street Noise - Downtown", "Street Noise_downtown"),
        ("Street Noise", "Street Noise"),
        ("Car Noise (60mph)", "Car Noise_60mph"),
        ("Car Idle (40mph)", "Car Noise_Idle Noise_40mph"),
        ("Car Idle (60mph)", "Car Noise_Idle Noise_60mph"),
        ("Construction - Crane", "Construction_Crane_Moving"),
        ("Construction - Drilling", "Construction_Drilling"),
        ("Construction - Jackhammer 1", "Construction_Jackhammer1"),
        ("Construction - Jackhammer 2", "Construction_Jackhammer2"),
        ("Construction - Trucks", "Construction_Trucks_Unloading"),
        ("Inside Flight", "Inside Flight"),
        ("Inside Train 1", "Inside Train_1"),
        ("Inside Train 2", "Inside Train_2"),
        ("Inside Train 3", "Inside Train_3"),
        ("PC Fan", "PC Fan Noise"),
        ("SSN IEEE", "SSN_IEEE"),
        ("Train 1", "Train1"),
        ("Train 2", "Train2"),
        ("Water Cooler", "Water Cooler"),
    )
}

# Display label -> file name of each clean speech clip.
# All clips follow the pattern emo_<emotion>_freeform.wav; "happiness" is
# currently left out of the selectable emotions.
SPEECH_FILES = {
    emotion.capitalize(): f"emo_{emotion}_freeform.wav"
    for emotion in ("amazement", "anger", "disgust", "fear", "sadness")
}

# Holds the arrays and metadata produced by the most recent processing run
processed_audio = dict()

def calculate_metrics(clean, degraded, sample_rate):
    """Calculate speech quality metrics for ``degraded`` against ``clean``.

    Both signals are truncated to the shorter of the two before any metric is
    computed. PESQ and STOI are best-effort: if the corresponding package is
    not installed, or the metric itself fails, the entry is ``None`` and the
    remaining metrics are still returned.

    Args:
        clean: 1-D NumPy array holding the reference signal.
        degraded: 1-D NumPy array holding the signal under test.
        sample_rate: Sampling rate in Hz, forwarded to PESQ and STOI.

    Returns:
        dict with keys ``'pesq'`` (float or None), ``'stoi'`` (float or None)
        and ``'snr'`` (dB; ``inf`` when the signals are identical, ``-inf``
        when the reference carries no energy but the difference does).
    """
    # Ensure same length
    min_len = min(len(clean), len(degraded))
    clean = clean[:min_len]
    degraded = degraded[:min_len]

    # Calculate PESQ (Perceptual Evaluation of Speech Quality).
    # The import lives inside the try so that a missing optional package only
    # disables this metric. `except Exception` (instead of a bare except) lets
    # KeyboardInterrupt / SystemExit propagate.
    try:
        from pesq import pesq
        pesq_score = pesq(sample_rate, clean, degraded, 'wb')  # wideband
    except Exception:
        pesq_score = None

    # Calculate STOI (Short-Time Objective Intelligibility)
    try:
        from pystoi import stoi
        stoi_score = stoi(clean, degraded, sample_rate, extended=False)
    except Exception:
        stoi_score = None

    # Calculate SNR: energy of the reference over energy of the difference
    signal_power = np.sum(clean ** 2)
    noise = degraded - clean
    noise_power = np.sum(noise ** 2)

    if noise_power > 0:
        if signal_power > 0:
            snr = 10 * np.log10(signal_power / noise_power)
        else:
            # Same value log10(0) would give, without the divide-by-zero warning
            snr = float('-inf')
    else:
        snr = float('inf')

    return {
        'pesq': pesq_score,
        'stoi': stoi_score,
        'snr': snr
    }

def process_audio_pipeline(clean_waveform, noise_waveform, noisy_speech, sample_rate):
    """Enhance ``noisy_speech`` with GTCRN, then post-process with MBAND and Wiener.

    The noisy signal is run through the module-level ``gtcrn_model`` in the
    STFT domain; the resynthesised waveform is then passed independently to
    ``mband`` and to ``wiener_filter``.

    Args:
        clean_waveform: Not used by the pipeline; kept so existing callers
            keep working.
        noise_waveform: Not used by the pipeline; kept so existing callers
            keep working.
        noisy_speech: torch tensor with the noisy waveform. A multi-dimensional
            tensor with singleton dimensions (for example the ``(1, T)`` tensor
            stored by the upload handler) is flattened to 1-D first.
        sample_rate: Sampling rate in Hz, forwarded as ``fs`` to both
            post-filters.

    Returns:
        Tuple ``(gtcrn_mband_enh, gtcrn_wiener_enh)``: the first return value
        of ``mband`` and of ``wiener_filter`` respectively.
    """
    # STFT geometry shared by analysis and synthesis.
    # NOTE(review): these sizes are fixed regardless of `sample_rate`; they
    # presumably match what the checkpoint was trained with -- confirm.
    n_fft, hop_length, win_length = 512, 256, 512
    window = torch.hann_window(win_length).pow(0.5)  # sqrt-Hann, built once

    # The model call below adds the batch dimension itself (`input_stft[None]`),
    # so the waveform must not carry a channel dimension of its own.
    if noisy_speech.dim() > 1:
        noisy_speech = noisy_speech.squeeze()

    # GTCRN inference
    input_stft = torch.stft(noisy_speech, n_fft, hop_length, win_length, window, return_complex=True)
    input_stft = torch.view_as_real(input_stft)  # complex -> trailing (real, imag) axis

    with torch.no_grad():
        output_stft = gtcrn_model(input_stft[None])[0]

    output_stft = torch.complex(output_stft[..., 0], output_stft[..., 1])
    gtcrn_enhanced = torch.istft(output_stft, n_fft, hop_length, win_length, window).detach().cpu()

    # GTCRN + MBAND (each post-filter gets its own copy of the GTCRN output)
    gtcrn_mband_enh, _ = mband(
        noisy_audio=gtcrn_enhanced.clone(),
        fs=sample_rate,
        Nband=8,
        Freq_spacing='linear',
        FRMSZ=20,
        OVLP=75,
        AVRGING=1,
        Noisefr=1,
        FLOOR=0.7,
        VAD=1,
    )

    # GTCRN + Wiener Filter
    gtcrn_wiener_enh, _ = wiener_filter(
        noisy_audio=gtcrn_enhanced.clone(),
        fs=sample_rate,
        frame_dur_ms=25,
        mu=0.95,
        a_dd=0.95,
        eta=0.15
    )

    return gtcrn_mband_enh, gtcrn_wiener_enh

def _draw_metric_panel(ax, methods, heights, labels, bar_colors, ylabel, title,
                       label_offset, ylim=None):
    """Draw one styled bar chart on ``ax`` and annotate bars that have a label."""
    ax.bar(methods, heights, color=bar_colors, alpha=0.8, edgecolor='white', linewidth=2)
    ax.set_ylabel(ylabel, fontweight='600', fontsize=11)
    ax.set_title(title, fontweight='600', fontsize=12, pad=15)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_facecolor('#FAFAFA')
    for spine in ax.spines.values():
        spine.set_edgecolor('#E0E0E0')

    # Value labels sit just outside the free end of each bar
    for i, (height, label) in enumerate(zip(heights, labels)):
        if label is None:
            continue
        if height >= 0:
            ax.text(i, height + label_offset, label, ha='center',
                    fontweight='600', fontsize=10)
        else:
            ax.text(i, height - label_offset, label, ha='center', va='top',
                    fontweight='600', fontsize=10)

    # Rotate x labels and drop the top/right frame lines
    ax.tick_params(axis='x', rotation=15)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def plot_metrics(metrics_dict, mode):
    """Create the PESQ / STOI / SNR comparison figure.

    Args:
        metrics_dict: Mapping of method name to a dict with the keys
            ``'pesq'``, ``'stoi'`` and ``'snr'`` (the shape returned by
            ``calculate_metrics``). ``None`` PESQ/STOI values are drawn as a
            zero-height bar without a label.
        mode: Not used by the plot; kept so existing callers keep working.

    Returns:
        The matplotlib figure holding the three bar charts.
    """
    methods = list(metrics_dict)
    results = list(metrics_dict.values())

    pesq_scores = [r['pesq'] if r['pesq'] is not None else 0 for r in results]
    stoi_scores = [r['stoi'] if r['stoi'] is not None else 0 for r in results]
    snr_scores = [r['snr'] for r in results]
    # An infinite SNR (identical signals) cannot be drawn as a bar height:
    # plot it as zero but keep the real value in the label.
    snr_heights = [v if np.isfinite(v) else 0 for v in snr_scores]

    # Color scheme; methods without a dedicated color fall back to neutral grey
    colors = {
        'Noisy': '#E53935',
        'GTCRN+MBAND': '#1976D2',
        'GTCRN+Wiener': '#7B1FA2'
    }
    bar_colors = [colors.get(m, '#757575') for m in methods]

    # Create subplots
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    # PESQ plot (missing / non-positive scores get no label)
    _draw_metric_panel(
        axes[0], methods, pesq_scores,
        [f'{v:.2f}' if v > 0 else None for v in pesq_scores],
        bar_colors, 'PESQ Score', 'Perceptual Speech Quality',
        label_offset=0.1, ylim=[0, 5],
    )

    # STOI plot
    _draw_metric_panel(
        axes[1], methods, stoi_scores,
        [f'{v:.2f}' if v > 0 else None for v in stoi_scores],
        bar_colors, 'STOI Score', 'Speech Intelligibility',
        label_offset=0.02, ylim=[0, 1],
    )

    # SNR plot (always labelled; the axis range is left automatic)
    _draw_metric_panel(
        axes[2], methods, snr_heights,
        [f'{v:.1f}' for v in snr_scores],
        bar_colors, 'SNR (dB)', 'Signal-to-Noise Ratio',
        label_offset=0.5,
    )

    plt.tight_layout()
    return fig


def _status_banner(gradient, icon, title_color, title, detail_color, detail):
    """Build the rounded status banner shown at the top of the output area.

    ``detail`` is inserted verbatim, so callers must HTML-escape untrusted text.
    """
    return HTML(f'''
    <div style="padding: 20px; background: linear-gradient(135deg, {gradient});
                border-radius: 12px; margin: 15px 0; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
        <div style="display: flex; align-items: center; gap: 12px;">
            <div style="font-size: 24px;">{icon}</div>
            <div>
                <div style="font-weight: 600; color: {title_color}; font-size: 1.1em;">{title}</div>
                <div style="color: {detail_color}; margin-top: 4px;">{detail}</div>
            </div>
        </div>
    </div>
    ''')


def _section_heading(title, subtitle=None):
    """Build the blue, underlined heading that introduces a result section."""
    subtitle_html = ''
    if subtitle:
        subtitle_html = f'''
        <p style="color: #666; margin-top: 10px; font-size: 0.95em;">
            {subtitle}
        </p>'''
    return HTML(f'''
    <div style="margin: 30px 0 20px 0;">
        <h2 style="color: #1976D2; font-weight: 600; font-size: 1.5em; margin: 0;">
            {title}
        </h2>
        <div style="height: 3px; width: 60px; background: linear-gradient(90deg, #1976D2, #64B5F6);
                    border-radius: 2px; margin-top: 8px;"></div>{subtitle_html}
    </div>
    ''')


def on_process_button_clicked(b):
    """Handle a click on the process button.

    Reads the current widget state, obtains a noisy signal (the uploaded file
    when the upload checkbox is ticked and a file has been loaded, otherwise a
    clean clip mixed with the selected noise at the selected SNR), runs it
    through ``process_audio_pipeline`` and renders the results -- metrics
    (only when a clean reference exists), waveforms and audio players -- into
    ``output_area``. The arrays are also stored in the module-level
    ``processed_audio`` dict.

    Any exception raised while processing or rendering is caught and shown as
    an error banner instead of propagating.

    Args:
        b: The object passed by the button's click callback (unused).
    """
    global processed_audio
    from html import escape  # local import: only needed for the error banner

    with output_area:
        clear_output(wait=True)

        # Show processing message
        display(_status_banner(
            '#FFF8E1 0%, #FFE082 100%', '⏳',
            '#F57F17', 'Processing Audio',
            '#F9A825', 'This may take 10-30 seconds...',
        ))

        try:
            # Get selected options
            if use_upload_checkbox.value and uploaded_audio_data is not None:
                # Use uploaded file
                noisy_speech = uploaded_audio_data['waveform']
                sample_rate = uploaded_audio_data['sample_rate']
                clean_waveform = noisy_speech.clone()  # No clean reference for uploaded
                noise_waveform = None
                mode = "uploaded"
            else:
                # Use selected files
                noise_file_path = noise_path / NOISE_FILES[noise_dropdown.value]
                speech_file_path = clean_sound_path / SPEECH_FILES[speech_dropdown.value]

                # Preprocess audio
                clean_waveform, noise_waveform, noisy_speech, sample_rate = preprocess_audio(
                    clean_speech=speech_file_path,
                    noisy_audio=noise_file_path,
                    snr_db=snr_slider.value
                )
                mode = "generated"

            # Only the generated mix comes with a clean reference signal
            has_reference = mode == "generated"

            # Process through pipelines
            gtcrn_mband_enhanced, gtcrn_wiener_enhanced = process_audio_pipeline(
                clean_waveform, noise_waveform, noisy_speech, sample_rate
            )

            # Trim every signal to the shortest one so they can share a time
            # axis and be compared sample by sample
            noisy_flat = noisy_speech.squeeze()
            mband_flat = gtcrn_mband_enhanced.squeeze()
            wiener_flat = gtcrn_wiener_enhanced.squeeze()
            clean_flat = clean_waveform.squeeze() if has_reference else None

            lengths = [len(noisy_flat), len(mband_flat), len(wiener_flat)]
            if has_reference:
                lengths.append(len(clean_flat))
            min_len = min(lengths)

            # Convert once; the arrays are reused for metrics, plots and playback
            noisy_np = noisy_flat[:min_len].numpy()
            mband_np = mband_flat[:min_len].numpy()
            wiener_np = wiener_flat[:min_len].numpy()
            clean_np = clean_flat[:min_len].numpy() if has_reference else None

            # Calculate metrics if clean reference available
            metrics_dict = None
            if has_reference:
                try:
                    metrics_dict = {
                        'Noisy': calculate_metrics(clean_np, noisy_np, sample_rate),
                        'GTCRN+MBAND': calculate_metrics(clean_np, mband_np, sample_rate),
                        'GTCRN+Wiener': calculate_metrics(clean_np, wiener_np, sample_rate),
                    }
                except Exception as e:
                    # Metrics are optional: report the problem and carry on
                    print(f"Warning: Could not calculate metrics: {e}")
                    metrics_dict = None

            # Store results
            processed_audio = {
                'noisy': noisy_np,
                'gtcrn_mband': mband_np,
                'gtcrn_wiener': wiener_np,
                'clean': clean_np,
                'sample_rate': sample_rate,
                'mode': mode,
                'metrics': metrics_dict
            }

            # Replace the "processing" banner with the results
            clear_output(wait=True)

            # Success message
            display(_status_banner(
                '#E8F5E9 0%, #A5D6A7 100%', '✅',
                '#2E7D32', 'Processing Complete',
                '#43A047', 'Listen to the enhanced audio below',
            ))

            # Display metrics first if available
            if metrics_dict is not None:
                display(_section_heading(
                    'Speech Quality Metrics',
                    'Objective evaluation of speech enhancement performance',
                ))

                plot_metrics(metrics_dict, mode)
                plt.show()

                # Metrics explanation
                display(HTML('''
                <div style="background: white; padding: 18px; border-radius: 10px; margin: 15px 0;
                            border-left: 4px solid #1976D2; box-shadow: 0 1px 4px rgba(0,0,0,0.08);">
                    <div style="color: #424242; font-size: 0.92em; line-height: 1.7;">
                        <strong style="color: #1976D2;">PESQ:</strong> Perceptual quality (1-5, higher is better) |
                        <strong style="color: #1976D2;">STOI:</strong> Intelligibility (0-1, higher is better) |
                        <strong style="color: #1976D2;">SNR:</strong> Signal-to-noise ratio (dB, higher is better)
                    </div>
                </div>
                '''))

            # Display waveforms
            display(_section_heading('📊 Waveform Comparison'))

            # One panel per signal, in display order: (title, samples, line color)
            panels = []
            if has_reference:
                panels.append(('Clean Speech (Reference)', clean_np, '#43A047'))
            panels.extend([
                ('Noisy Speech', noisy_np, '#E53935'),
                ('GTCRN + MBAND Enhanced', mband_np, '#1976D2'),
                ('GTCRN + Wiener Enhanced', wiener_np, '#7B1FA2'),
            ])

            # There are always at least three panels, so `axes` is an array
            fig, axes = plt.subplots(len(panels), 1, figsize=(15, 3 * len(panels)))
            time_axis = np.arange(min_len) / sample_rate

            for number, (ax, (title, samples, color)) in enumerate(zip(axes, panels), start=1):
                ax.plot(time_axis, samples, color=color, alpha=0.8, linewidth=0.9)
                ax.set_title(f'{number}. {title}', fontsize=13, fontweight='600', pad=12)
                ax.set_ylabel('Amplitude', fontweight='500')
                ax.grid(True, alpha=0.2, linestyle='--')
                ax.set_facecolor('#FAFAFA')
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
            # Only the bottom panel carries the time-axis label
            axes[-1].set_xlabel('Time (s)', fontweight='500')

            plt.tight_layout(pad=2.0)
            plt.show()

            # Display audio players
            display(_section_heading('🎧 Audio Playback'))

            # (title, description, samples, border color, title color).
            # The clean reference is plotted above but gets no player here; its
            # player was disabled (commented out) in the previous version.
            cards = [
                ('Noisy Speech',
                 'Speech degraded by environmental noise',
                 noisy_np, '#E53935', '#C62828'),
                ('GTCRN + MBAND Enhanced',
                 'Deep learning (GTCRN) + Multi-band spectral subtraction',
                 mband_np, '#1976D2', '#1565C0'),
                ('GTCRN + Wiener Enhanced',
                 'Deep learning (GTCRN) + Wiener filter post-processing',
                 wiener_np, '#7B1FA2', '#6A1B9A'),
            ]

            for play_idx, (title, blurb, samples, border, title_color) in enumerate(cards, start=1):
                display(HTML(f'''
                <div style="background: white; padding: 20px; border-radius: 12px; margin: 15px 0;
                            box-shadow: 0 2px 12px rgba(0,0,0,0.08); border-left: 4px solid {border};">
                    <div style="font-weight: 600; font-size: 1.15em; color: {title_color}; margin-bottom: 8px;">
                        {play_idx}. {title}
                    </div>
                    <div style="color: #666; font-size: 0.95em; margin-bottom: 12px;">
                        {blurb}
                    </div>
                '''))
                display(ipd.Audio(samples, rate=sample_rate))
                display(HTML('</div>'))

        except Exception as e:
            # Top-level UI boundary: show the failure instead of a traceback.
            # The message is escaped because exception text may contain
            # characters such as '<' that would otherwise be parsed as markup.
            clear_output(wait=True)
            display(_status_banner(
                '#FFEBEE 0%, #EF9A9A 100%', '❌',
                '#C62828', 'Processing Error',
                '#E53935', escape(str(e)),
            ))

# Store uploaded audio: None until a file has been loaded successfully,
# afterwards a dict with the keys 'waveform' and 'sample_rate'
uploaded_audio_data = None


def _extract_upload_payload(value):
    """Return ``(raw_bytes, filename_or_None)`` for the first uploaded file.

    Accepts the sequence of per-file entries produced by ipywidgets 8 as well
    as the ``{filename: {'metadata': ..., 'content': ...}}`` mapping produced
    by ipywidgets 7, plus bare bytes-like or file-like content.
    """
    if isinstance(value, (list, tuple)):
        entry = value[0]
    elif isinstance(value, dict) and 'content' not in value:
        # Old format: the mapping is keyed by file name, take the first file
        entry = next(iter(value.values()))
    else:
        entry = value

    filename = None
    if isinstance(entry, dict):
        content = entry['content']
        filename = entry.get('name') or (entry.get('metadata') or {}).get('name')
    else:
        content = entry

    # ipywidgets 8 hands the content over as a memoryview, older versions as bytes
    if isinstance(content, (bytes, bytearray, memoryview)):
        return bytes(content), filename
    # Otherwise assume a file-like object
    return content.read(), filename


def on_upload_change(change):
    """Handle a change of the upload widget's value.

    Writes the first uploaded file to a temporary file, loads it with
    torchaudio, down-mixes it to mono, resamples it to 16 kHz if necessary and
    stores the result in the module-level ``uploaded_audio_data``. A success
    or error box is rendered into ``upload_output``; on failure
    ``uploaded_audio_data`` keeps its previous value.

    Args:
        change: The traitlets change dict; only ``change['new']`` is read.
    """
    global uploaded_audio_data
    import tempfile
    from html import escape

    with upload_output:
        clear_output(wait=True)

        # An empty value (e.g. the upload was cleared) needs no handling
        if not change['new']:
            return

        temp_path = None
        try:
            payload, filename = _extract_upload_payload(change['new'])

            # Keep the original extension so that an mp3 upload is not stored
            # under a .wav name; anything unexpected falls back to .wav.
            suffix = Path(filename).suffix.lower() if filename else ''
            if suffix not in ('.wav', '.mp3'):
                suffix = '.wav'

            # Save temporarily (unique name, outside the working directory)
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(payload)
                temp_path = Path(tmp.name)

            # Load audio
            waveform, sample_rate = torchaudio.load(temp_path)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to 16kHz if needed
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
                sample_rate = 16000

            uploaded_audio_data = {
                'waveform': waveform,
                'sample_rate': sample_rate
            }

            display(HTML(f'''
            <div style="padding: 16px; background: linear-gradient(135deg, #E8F5E9 0%, #A5D6A7 100%);
                        border-radius: 10px; border-left: 4px solid #43A047;">
                <div style="font-weight: 600; color: #2E7D32; margin-bottom: 6px;">✓ File Uploaded Successfully</div>
                <div style="color: #43A047; font-size: 0.95em;">
                    Duration: {len(waveform[0])/sample_rate:.2f}s | Sample Rate: {sample_rate} Hz
                </div>
            </div>
            '''))

        except Exception as e:
            # UI boundary: report the problem in the widget instead of raising.
            # The message is escaped because it may contain markup characters.
            display(HTML(f'''
            <div style="padding: 16px; background: linear-gradient(135deg, #FFEBEE 0%, #EF9A9A 100%);
                        border-radius: 10px; border-left: 4px solid #E53935;">
                <div style="font-weight: 600; color: #C62828; margin-bottom: 6px;">Error Loading File</div>
                <div style="color: #E53935; font-size: 0.95em;">{escape(str(e))}</div>
            </div>
            '''))
        finally:
            # Clean up the temporary file whether or not loading succeeded
            if temp_path is not None and temp_path.exists():
                temp_path.unlink()

# Create GUI components with improved styling.
# `style` and `layout` are shared by every input widget created below so that
# descriptions and controls line up.
style = {'description_width': '160px'}
layout = Layout(width='500px')

# Header (the title glyph was mis-encoded; restored to the musical-note emoji)
display(HTML("""
<div style="background: linear-gradient(135deg, #1976D2 0%, #1565C0 100%);
            padding: 40px; border-radius: 16px; margin-bottom: 30px;
            box-shadow: 0 8px 32px rgba(25, 118, 210, 0.3);">
    <h1 style="margin: 0; font-size: 2.5em; color: white; font-weight: 600; letter-spacing: -0.5px;">
        🎵 Speech Enhancement System
    </h1>
    <p style="margin: 12px 0 0 0; font-size: 1.15em; color: rgba(255,255,255,0.95); font-weight: 400;">
        Advanced Audio Processing with Deep Learning & DSP Algorithms
    </p>
</div>
"""))

# Mode selection section (heading glyph restored to the folder emoji)
display(HTML('''
<div style="margin: 25px 0 15px 0;">
    <h2 style="color: #424242; font-weight: 600; font-size: 1.3em; margin: 0;">
        📁 Input Selection
    </h2>
    <div style="height: 3px; width: 50px; background: linear-gradient(90deg, #1976D2, #64B5F6);
                border-radius: 2px; margin-top: 8px;"></div>
</div>
'''))

# When ticked (and a file has been loaded), processing uses the uploaded
# audio instead of the generated speech + noise mix.
use_upload_checkbox = widgets.Checkbox(
    value=False,
    description='Upload your own noisy audio file',
    style=style,
    layout=layout,
    indent=False
)

# Created before the observer is registered because the upload handler
# renders its status messages into this output widget.
upload_output = widgets.Output()

upload_widget = widgets.FileUpload(
    accept='.wav,.mp3',
    multiple=False,
    description='Select Audio File',
    style=style,
    layout=layout
)
# Run the handler whenever the widget's value (the uploaded file) changes
upload_widget.observe(on_upload_change, names='value')

display(use_upload_checkbox)
display(upload_widget)
display(upload_output)

# Configuration section (heading glyph restored to the control-knobs emoji)
display(HTML('''
<div style="margin: 30px 0 15px 0;">
    <h2 style="color: #424242; font-weight: 600; font-size: 1.3em; margin: 0;">
        🎛️ Audio Configuration
    </h2>
    <div style="height: 3px; width: 50px; background: linear-gradient(90deg, #1976D2, #64B5F6);
                border-radius: 2px; margin-top: 8px;"></div>
    <p style="color: #757575; margin: 10px 0 15px 0; font-size: 0.95em;">
        Configure noise and speech settings (used when not uploading a file)
    </p>
</div>
'''))

# The dropdown options are the display labels; the click handler maps the
# selected label back to a file name through NOISE_FILES / SPEECH_FILES.
noise_dropdown = widgets.Dropdown(
    options=list(NOISE_FILES.keys()),
    value='Street Noise - Downtown',
    description='Noise Type:',
    style=style,
    layout=layout
)

speech_dropdown = widgets.Dropdown(
    options=list(SPEECH_FILES.keys()),
    value='Amazement',
    description='Speech Emotion:',
    style=style,
    layout=layout
)

# Target SNR of the generated mix, in dB. continuous_update=False means the
# value only changes once the slider is released.
snr_slider = widgets.IntSlider(
    value=0,
    min=-10,
    max=20,
    step=1,
    description='SNR Level (dB):',
    style=style,
    layout=layout,
    continuous_update=False
)

display(noise_dropdown)
display(speech_dropdown)
display(snr_slider)
# Hint text; the 165px left margin lines it up with the slider track
display(HTML('''
<p style="color: #757575; font-size: 0.9em; margin: 5px 0 0 165px; line-height: 1.5;">
    <strong>Lower values:</strong> More challenging noise conditions<br>
    <strong>Higher values:</strong> Cleaner signal with less noise
</p>
'''))

# Process button
# NOTE(review): every display() call produces its own output element, so the
# opening and closing <div> below presumably cannot wrap the button and act
# only as a spacer -- confirm, and consider a Layout margin on the button.
display(HTML('<div style="margin: 30px 0 20px 0;">'))
process_button = widgets.Button(
    description='🎛️  Process Audio',  # label glyph restored from mis-encoded text
    button_style='',
    layout=Layout(width='500px', height='56px'),
    style={'button_color': '#1976D2', 'font_weight': '600', 'font_size': '16px'}
)
process_button.on_click(on_process_button_clicked)
display(process_button)
display(HTML('</div>'))

# Output area: the click handler renders banners, plots and players in here.
# It is displayed after the button so that results appear below it.
output_area = widgets.Output()
display(output_area)

# Info box with static usage instructions (heading glyph restored to the
# light-bulb emoji)
display(HTML("""
<div style="background: linear-gradient(135deg, #E3F2FD 0%, #BBDEFB 100%);
            padding: 20px; border-left: 4px solid #1976D2;
            border-radius: 12px; margin-top: 30px; box-shadow: 0 2px 8px rgba(0,0,0,0.06);">
    <div style="color: #1565C0; line-height: 1.8;">
        <div style="font-weight: 600; font-size: 1.1em; margin-bottom: 10px;">💡 Usage Instructions</div>
        <div style="font-size: 0.95em;">
            <strong>1.</strong> Choose input method: Upload your noisy audio OR use predefined settings<br>
            <strong>2.</strong> Click "Process Audio" and wait 10-30 seconds for processing<br>
            <strong>3.</strong> Compare the enhancement techniques in waveforms and audio playback<br>
            <strong>4.</strong> GTCRN + MBAND and GTCRN + Wiener offer state-of-the-art noise reduction
        </div>
    </div>
</div>
"""))

Loading models...
✓ Models loaded successfully!


Checkbox(value=False, description='Upload your own noisy audio file', indent=False, layout=Layout(width='500px…

FileUpload(value=(), accept='.wav,.mp3', description='Select Audio File', layout=Layout(width='500px'))

Output()

Dropdown(description='Noise Type:', index=1, layout=Layout(width='500px'), options=('Cafeteria Babble', 'Stree…

Dropdown(description='Speech Emotion:', layout=Layout(width='500px'), options=('Amazement', 'Anger', 'Disgust'…

IntSlider(value=0, continuous_update=False, description='SNR Level (dB):', layout=Layout(width='500px'), max=2…

Button(description='🎛️  Process Audio', layout=Layout(height='56px', width='500px'), style=ButtonStyle(button_…

Output()