In [None]:
!pip install noisereduce

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import soundfile as sf
from scipy import signal
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.layers import Dense, LSTM, Dropout, Bidirectional, Input, Concatenate, Layer, Conv1D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from IPython.display import Audio, display
import torch
import torch.nn as nn
import math
import glob
import warnings
import random
warnings.filterwarnings('ignore')
import glob
from tensorflow.keras.models import load_model
from IPython.display import Audio, display, HTML

In [None]:
# Seed every random number generator used by the notebook so that a
# Restart & Run All produces the same results. The stdlib `random` module is
# included because the training file list is shuffled with random.shuffle.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
torch.manual_seed(SEED)

# Output locations, created relative to the current working directory.
output_dir = "improved_hearing_aid_results"
models_dir = "improved_hearing_aid_models"
test_audio_dir = "test_audio_samples"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(test_audio_dir, exist_ok=True)

# Root folder of the dataset. Set the VALENTINI_DATA_DIR environment variable
# to run on another machine; the previous hard-coded location stays the default.
audio_files_dir = os.environ.get(
    "VALENTINI_DATA_DIR",
    r"C:\Users\rucha\.cache\kagglehub\datasets\muhmagdy\valentini-noisy\versions\3",
)
train_clean_dir = os.path.join(audio_files_dir, 'clean_trainset_28spk_wav')
train_noisy_dir = os.path.join(audio_files_dir, 'noisy_trainset_28spk_wav')
test_clean_dir = os.path.join(audio_files_dir, 'clean_testset_wav')
test_noisy_dir = os.path.join(audio_files_dir, 'noisy_testset_wav')

print("Dataset directories:")
print(f"Train clean: {train_clean_dir}")
print(f"Train noisy: {train_noisy_dir}")
print(f"Test clean: {test_clean_dir}")
print(f"Test noisy: {test_noisy_dir}")

# Hearing-loss presets. Each list holds one attenuation value in dB per
# frequency band, in the order of the bands delimited by FREQ_BANDS.
HEARING_LOSS_PROFILES = {
    'mild': [5, 10, 15, 20, 25],                # Mild loss
    'moderate': [10, 20, 35, 45, 50],           # Moderate loss
    'severe': [20, 35, 55, 70, 80],             # Severe loss
    'high_freq': [0, 5, 15, 35, 60],            # High frequency loss (most common)
    'cookie_bite': [15, 30, 40, 30, 20],        # Mid-frequency loss
    'reverse_slope': [45, 35, 25, 15, 5],       # Low-frequency loss (reverse slope)
    'flat': [30, 30, 30, 30, 30]                # Flat loss
}

# Band edges in Hz; each consecutive pair of edges defines one analysis band.
FREQ_BANDS = [0, 500, 1000, 2000, 4000, 8000]

In [None]:
class Constants:
    """Container for fixed settings; instantiated by MBSTOI.__init__."""

    def __init__(self):
        # fs is presumably a sampling rate in Hz - TODO confirm.
        self.fs, self.gridcoarseness = 10000, 1

def thirdoct(fs, nfft, num_bands, min_freq):
    """Build a one-third-octave band matrix over the lower half of a frequency grid.

    Args:
        fs: Upper end of the frequency grid (the grid spans 0..fs in nfft + 1 points).
        nfft: Grid size; only the first nfft // 2 + 1 points are kept.
        num_bands: Number of bands to generate.
        min_freq: Centre frequency of the first band.

    Returns:
        Tuple (obm, cf, fids, freq_low, freq_high) where
        obm is a (num_bands, nfft // 2 + 1) matrix whose row i holds
        1 / (k2 - k1) on the bins between the grid points closest to the
        band's lower and upper edges, cf holds the centre frequencies,
        fids holds 1-based [first, last] bin indices for every band that
        contains at least one grid point, and freq_low / freq_high hold the
        band edges.
    """
    freqs = np.linspace(0, fs, nfft + 1)[:nfft // 2 + 1]

    band_index = np.arange(num_bands)
    cf = min_freq * 2**(band_index/3)
    freq_low = min_freq * 2**((band_index-0.5)/3)
    freq_high = min_freq * 2**((band_index+0.5)/3)

    # Each row averages the bins between the grid points nearest to the edges.
    obm = np.zeros((num_bands, len(freqs)))
    for row, (lower, upper) in enumerate(zip(freq_low, freq_high)):
        first_bin = np.argmin((freqs - lower)**2)
        last_bin = np.argmin((freqs - upper)**2)
        if last_bin > first_bin:
            obm[row, first_bin:last_bin] = 1 / (last_bin - first_bin)

    # 1-based bin ranges; bands without any grid point inside are left out.
    bin_ranges = []
    for band in range(num_bands):
        edge_lo = min_freq * 2**((band-0.5)/3)
        edge_hi = min_freq * 2**((band+0.5)/3)
        inside = np.flatnonzero((freqs >= edge_lo) & (freqs <= edge_hi))
        if inside.size > 0:
            bin_ranges.append([inside[0] + 1, inside[-1] + 1])

    return obm, cf, np.array(bin_ranges), freq_low, freq_high

class MBSTOI(nn.Module):
    """Frame-correlation intelligibility score for left/right signal pairs.

    NOTE(review): despite the name, this class performs no third-octave
    analysis or equalisation-cancellation stage. The score is the mean Pearson
    correlation between Hann-windowed reference and processed frames, computed
    per ear on the frames kept by the silence detector; the larger of the two
    ear scores is returned.
    """

    def __init__(self):
        super().__init__()
        self.constants = Constants()

    @staticmethod
    def _as_float_tensor(x):
        """Return x unchanged if it is a tensor, else convert it to float32."""
        if torch.is_tensor(x):
            return x
        return torch.tensor(x, dtype=torch.float32)

    def _signal_to_frames(self, x, framelen, hop):
        """Cut x into overlapping Hann-windowed frames.

        Args:
            x: Tensor of shape (samples,) or (batch, samples).
            framelen: Frame length in samples.
            hop: Step between frame starts in samples.

        Returns:
            Tensor of shape (batch, num_frames, framelen); num_frames is 0
            when the signal is shorter than one frame.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # unfold() raises when the signal is shorter than the frame, so the
        # "no frames" case is handled explicitly.
        if x.shape[1] < framelen:
            return torch.zeros((x.shape[0], 0, framelen), device=x.device)

        frames = x.unfold(1, framelen, hop)
        window = torch.hann_window(framelen, device=x.device)
        return frames * window

    def _detect_silent_frames(self, x, dyn_range, framelen, hop):
        """Flag frames whose energy is within dyn_range dB of the loudest frame.

        Returns:
            (mask, frames): mask is a boolean (batch, num_frames) tensor that
            is True for frames to keep; frames are the windowed frames of x.
        """
        x = self._as_float_tensor(x)
        frames = self._signal_to_frames(x, framelen, hop)

        # torch.max over an empty frame axis raises, so return an empty mask.
        if frames.shape[1] == 0:
            empty = torch.zeros(frames.shape[:2], dtype=torch.bool, device=frames.device)
            return empty, frames

        energy = 10 * torch.log10(torch.sum(frames**2, dim=2) + 1e-8)
        max_energy = torch.max(energy, dim=1, keepdim=True)[0]
        mask = energy > (max_energy - dyn_range)

        return mask, frames

    def _remove_silent_frames(self, xl, xr, yl, yr, dyn_range, N_frame, hop):
        """Frame all four signals and keep frames active in either reference.

        Activity is decided on the reference signals (xl, xr) only; the same
        frame selection is then applied to the processed signals (yl, yr).

        Returns:
            Tuple of four (batch, kept_frames, N_frame) tensors in the order
            xl, xr, yl, yr.
        """
        xl, xr, yl, yr = (self._as_float_tensor(s) for s in (xl, xr, yl, yr))

        mask_xl, frames_xl = self._detect_silent_frames(xl, dyn_range, N_frame, hop)
        mask_xr, frames_xr = self._detect_silent_frames(xr, dyn_range, N_frame, hop)

        # Reduce over the batch axis instead of calling squeeze(): squeeze()
        # also drops the frame axis when there is exactly one frame, which
        # turns the mask into a 0-dim tensor and corrupts the indexing below.
        mask = torch.logical_or(mask_xl, mask_xr).any(dim=0)

        frames_yl = self._signal_to_frames(yl, N_frame, hop)
        frames_yr = self._signal_to_frames(yr, N_frame, hop)

        return (
            frames_xl[:, mask, :],
            frames_xr[:, mask, :],
            frames_yl[:, mask, :],
            frames_yr[:, mask, :],
        )

    def stoi_measure(self, x_frames, y_frames):
        """Mean per-frame Pearson correlation between two frame tensors.

        Args:
            x_frames: Reference frames, shape (batch, num_frames, frame_len).
            y_frames: Processed frames with the same shape.

        Returns:
            0-dim tensor with the mean correlation, or 0.0 if there are no frames.
        """
        if x_frames.shape[1] == 0:
            return torch.tensor(0.0, device=x_frames.device)

        x_centered = x_frames - torch.mean(x_frames, dim=2, keepdim=True)
        y_centered = y_frames - torch.mean(y_frames, dim=2, keepdim=True)

        num = torch.sum(x_centered * y_centered, dim=2)
        # The small constant keeps the division finite for all-zero frames.
        denom = torch.sqrt(
            torch.sum(x_centered**2, dim=2) * torch.sum(y_centered**2, dim=2) + 1e-8
        )
        return torch.mean(num / denom)

    def forward(self, xl, xr, yl, yr):
        """Score processed signals (yl, yr) against references (xl, xr).

        All inputs are flattened, cut to a common length (at most 32000
        samples) and framed with 256-sample frames at 50% overlap. Frames more
        than 40 dB below the loudest reference frame are discarded.

        Returns:
            0-dim tensor: the larger of the left-ear and right-ear scores.
            The result is 0.0 when the signals are shorter than one frame.
        """
        N_frame = 256
        hop = N_frame // 2
        dyn_range = 40
        max_len = 32000

        xl, xr, yl, yr = (
            self._as_float_tensor(s).reshape(-1) for s in (xl, xr, yl, yr)
        )

        common_len = min(xl.shape[0], xr.shape[0], yl.shape[0], yr.shape[0], max_len)
        xl, xr, yl, yr = (s[:common_len].unsqueeze(0) for s in (xl, xr, yl, yr))

        xl_frames, xr_frames, yl_frames, yr_frames = self._remove_silent_frames(
            xl, xr, yl, yr, dyn_range, N_frame, hop
        )

        stoi_l = self.stoi_measure(xl_frames, yl_frames)
        stoi_r = self.stoi_measure(xr_frames, yr_frames)

        return torch.maximum(stoi_l, stoi_r)

def evaluate_mbstoi(clean_left, clean_right, processed_left, processed_right):
    """Score processed left/right signals against clean references.

    Every input is converted to a float32 tensor and passed to a fresh MBSTOI
    instance with gradient tracking disabled.

    Returns:
        The score as a Python float.
    """
    signals = (clean_left, clean_right, processed_left, processed_right)
    xl, xr, yl, yr = [torch.tensor(sig, dtype=torch.float32) for sig in signals]

    metric = MBSTOI()
    with torch.no_grad():
        return metric(xl, xr, yl, yr).item()

In [None]:
def apply_filter_safely(audio, sr, low_freq, high_freq):
    """Zero-phase filter audio to the band [low_freq, high_freq].

    A second-order Butterworth filter is applied with filtfilt. The filter is
    a low-pass when low_freq <= 0, a high-pass when high_freq reaches sr / 2,
    and a band-pass otherwise. Cutoffs are normalised by sr / 2 and capped at
    0.99.

    Returns:
        The filtered signal with non-finite values replaced by nan_to_num, or
        an all-zero array of the same shape if filtering raises.
    """
    nyquist = sr/2

    def normalised(freq_hz):
        return min(freq_hz / nyquist, 0.99)

    if low_freq <= 0:
        btype, cutoff = 'lowpass', normalised(high_freq)
    elif high_freq >= nyquist:
        btype, cutoff = 'highpass', normalised(low_freq)
    else:
        btype, cutoff = 'bandpass', [normalised(low_freq), normalised(high_freq)]

    try:
        b, a = signal.butter(2, cutoff, btype=btype)
        result = signal.filtfilt(b, a, audio)
        return result if np.all(np.isfinite(result)) else np.nan_to_num(result)
    except Exception as e:
        # Best effort: e.g. filtfilt rejects inputs shorter than its padding.
        print(f"Filtering error: {e} - Returning zeros")
        return np.zeros_like(audio)

def extract_band_energies(audio, sr):
    """Split audio into the FREQ_BANDS bands and measure each band's RMS.

    Returns:
        (band_signals, band_energies): an array with one filtered signal per
        band, and an array with the root-mean-square value of each band.
    """
    band_edges = zip(FREQ_BANDS[:-1], FREQ_BANDS[1:])
    band_signals = [
        apply_filter_safely(audio, sr, lower, upper) for lower, upper in band_edges
    ]
    band_energies = [np.sqrt(np.mean(band**2)) for band in band_signals]

    return np.array(band_signals), np.array(band_energies)

def apply_hearing_loss(audio, sr, loss_profile='high_freq'):
    """Simulate hearing loss by attenuating each frequency band of audio.

    Args:
        audio: Input signal.
        sr: Sampling rate passed on to the band filters.
        loss_profile: Either a key of HEARING_LOSS_PROFILES or a sequence with
            one attenuation value in dB per band of FREQ_BANDS.

    Returns:
        The sum of the attenuated bands, rescaled to a peak of 0.95 whenever
        its peak exceeds 0.95.
    """
    if isinstance(loss_profile, str):
        loss_db = HEARING_LOSS_PROFILES[loss_profile]
    else:
        loss_db = loss_profile

    # Split audio into frequency bands
    band_signals = [
        apply_filter_safely(audio, sr, FREQ_BANDS[i], FREQ_BANDS[i + 1])
        for i in range(len(FREQ_BANDS) - 1)
    ]

    # Convert each band's loss from dB to a linear amplitude factor.
    attenuated_bands = [
        band * 10 ** (-loss_db[i] / 20) for i, band in enumerate(band_signals)
    ]

    output = np.sum(attenuated_bands, axis=0)

    max_amp = np.max(np.abs(output))
    if max_amp > 0.95:
        output = output * (0.95 / max_amp)

    return output

In [None]:
def add_noise(audio, snr_range=(5, 15), noise_type='white'):
    """Mix synthetic noise into audio at a random signal-to-noise ratio.

    Args:
        audio: Input signal.
        snr_range: (low, high) bounds in dB; the SNR is drawn uniformly from it.
        noise_type: 'white', 'pink', 'speech_shaped' or 'babble'. Any other
            value falls back to white noise.

    Returns:
        audio plus scaled noise, rescaled to a peak of 0.95 whenever its peak
        exceeds 0.95.
    """
    if noise_type == 'white':
        noise = np.random.normal(0, 1, size=audio.shape)
    elif noise_type == 'pink':
        white_noise = np.random.normal(0, 1, size=audio.shape)
        b = [0.049922035, -0.095993537, 0.050612699, -0.004408786]
        a = [1, -2.494956002, 2.017265875, -0.522189400]
        noise = signal.lfilter(b, a, white_noise)
    elif noise_type == 'speech_shaped':
        white_noise = np.random.normal(0, 1, size=audio.shape)
        lpc_order = 12
        try:
            # NOTE(review): scipy.signal does not appear to provide `lpc`, so
            # this call presumably always fails and the fixed one-pole filter
            # below is what actually shapes the noise - confirm.
            a = signal.lpc(audio, lpc_order)[0]
            noise = signal.lfilter([1], a, white_noise)
        except Exception:
            noise = signal.lfilter([1], [1, -0.95], white_noise)
    elif noise_type == 'babble':
        noise = np.zeros_like(audio)
        for _ in range(6):  # Combine 6 shifted copies
            shift = np.random.randint(1000, 10000)
            noise += np.roll(audio, shift) * np.random.uniform(0.1, 0.3)
        # Remove the component correlated with the original signal. Skipped
        # for silent input, where the projection would divide by zero.
        audio_power = np.mean(audio**2)
        if audio_power > 0:
            noise = noise - np.mean(noise * audio) * audio / audio_power
    else:
        noise = np.random.normal(0, 1, size=audio.shape)

    signal_power = np.mean(audio**2)
    noise_power = np.mean(noise**2)

    snr = np.random.uniform(snr_range[0], snr_range[1])
    if noise_power > 0:
        noise_scale = np.sqrt(signal_power / (noise_power * 10**(snr/10)))
    else:
        # All-zero noise (e.g. babble built from silence): nothing to scale.
        noise_scale = 0.0

    noisy_audio = audio + noise * noise_scale

    max_amp = np.max(np.abs(noisy_audio))
    if max_amp > 0.95:
        noisy_audio = noisy_audio * (0.95 / max_amp)

    return noisy_audio

def add_reverberation(audio, sr, rt60_range=(0.2, 0.8)):
    """Convolve audio with a random, exponentially decaying impulse response.

    The response lasts rt60 * sr samples, with rt60 drawn uniformly from
    rt60_range. It is Gaussian noise under an envelope that falls to
    exp(-6.91) (about 1e-3, i.e. roughly 60 dB down) at its end, with a unit
    first tap, normalised so that the absolute values of its taps sum to 1.

    Returns:
        The reverberated signal truncated to len(audio), rescaled to a peak
        of 0.95 whenever its peak exceeds 0.95.
    """
    rt60 = np.random.uniform(rt60_range[0], rt60_range[1])
    ir_length = int(rt60 * sr)

    envelope = np.exp(-6.91 * np.arange(ir_length) / ir_length)
    impulse_response = np.random.randn(ir_length) * envelope
    impulse_response[0] = 1.0  # direct path
    impulse_response = impulse_response / np.sum(np.abs(impulse_response))

    wet = signal.convolve(audio, impulse_response)[:len(audio)]

    peak = np.max(np.abs(wet))
    return wet * (0.95 / peak) if peak > 0.95 else wet

def apply_channel_effects(audio, sr):
    """Apply a random per-band equalisation to audio.

    Each band returned by extract_band_energies is scaled by a gain drawn
    uniformly from [0.7, 1.3) and the scaled bands are summed.

    Returns:
        The equalised signal, rescaled to a peak of 0.95 whenever its peak
        exceeds 0.95.
    """
    bands, _ = extract_band_energies(audio, sr)
    gains = np.random.uniform(0.7, 1.3, size=len(bands))

    shaped = np.zeros_like(audio)
    for band, gain in zip(bands, gains):
        shaped += band * gain

    peak = np.max(np.abs(shaped))
    if peak > 0.95:
        shaped = shaped * (0.95 / peak)

    return shaped

def create_challenging_audio(clean_audio, sr):
    """Degrade a copy of clean_audio with a randomly chosen set of effects.

    One of five options is drawn uniformly: 'noise', 'reverb', 'noise+reverb',
    'channel' or 'all'. 'all' applies noise, reverberation and channel
    effects, in that order.

    Returns:
        The degraded signal; clean_audio itself is not modified.
    """
    effects = np.random.choice([
        'noise', 'reverb', 'noise+reverb', 'channel', 'all'
    ])

    # Exact membership tests: the earlier substring checks ('noise' in effects)
    # never matched the 'all' option, which therefore skipped noise and reverb.
    use_noise = effects in ('noise', 'noise+reverb', 'all')
    use_reverb = effects in ('reverb', 'noise+reverb', 'all')
    use_channel = effects in ('channel', 'all')

    audio = clean_audio.copy()

    if use_noise:
        noise_type = np.random.choice(['white', 'pink', 'speech_shaped', 'babble'])
        audio = add_noise(audio, snr_range=(3, 10), noise_type=noise_type)

    if use_reverb:
        audio = add_reverberation(audio, sr, rt60_range=(0.3, 1.0))

    if use_channel:
        audio = apply_channel_effects(audio, sr)

    return audio

In [None]:
def prepare_diverse_binaural_data(file_list, file_dir, hearing_profiles=None, 
                                 max_files=20, challenging_audio=True):
    """Build per-frame band-energy features and per-profile gain targets.

    For each of the first max_files entries of file_list and each hearing
    profile, a left/right signal pair is derived from the file, optionally
    degraded, passed through the hearing-loss simulation, and cut into frames.
    Each frame contributes its band energies as features and the profile's
    gains as targets.

    Args:
        file_list: File names relative to file_dir.
        file_dir: Directory containing the audio files.
        hearing_profiles: Keys of HEARING_LOSS_PROFILES to use; all keys when None.
        max_files: Maximum number of entries of file_list to process.
        challenging_audio: When True, each pair is degraded with probability 0.7.

    Returns:
        (X_left, X_right, y_left, y_right, sample_audios). X_* have shape
        (num_frames, 1, num_bands), y_* have shape (num_frames, num_bands),
        and sample_audios is a list of up to 10 dicts with signal excerpts.

    Raises:
        ValueError: If no frame could be extracted from any file.
    """

    if hearing_profiles is None:
        hearing_profiles = list(HEARING_LOSS_PROFILES.keys())
    
    X_left = []  
    X_right = []  
    y_left = []  
    y_right = []  
    
    sample_audios = []
    
    print(f"Processing files with diverse hearing profiles and conditions...")
    
    for i, file_name in enumerate(file_list):
        if i >= max_files:
            break
            
        if i % 5 == 0:
            print(f"Processing file {i+1}/{min(len(file_list), max_files)}")
        
        try:
            file_path = os.path.join(file_dir, file_name)
            audio, sr = sf.read(file_path)
            
            if not np.all(np.isfinite(audio)):
                print(f"Warning: File {file_name} contains invalid values. Skipping.")
                continue
            
            for profile in hearing_profiles:
                profile_losses = HEARING_LOSS_PROFILES[profile]
                # Target gain per band: 1 + loss_dB / 20 (e.g. 80 dB -> 5.0).
                # NOTE(review): build_binaural_model scales a sigmoid by 3.0,
                # so targets above 3 look unreachable - confirm intent.
                profile_gains = [1 + (loss/20) for loss in profile_losses]
                
                # The right ear is a copy circularly shifted by 4 samples and
                # scaled by 0.8.
                audio_left = audio.copy()
                audio_right = np.roll(audio, 4) * 0.8  
                
                if challenging_audio:
                    if np.random.random() < 0.7:  
                        audio_left = create_challenging_audio(audio_left, sr)
                        audio_right = create_challenging_audio(audio_right, sr)
                
                audio_left_with_loss = apply_hearing_loss(audio_left, sr, profile)
                audio_right_with_loss = apply_hearing_loss(audio_right, sr, profile)
                
                # Keep up to 10 excerpts (at most 3 seconds each) for later evaluation.
                # NOTE(review): the 'clean_*' entries are taken after the
                # optional degradation above, so they may already contain
                # noise/reverb; 'noisy_*' are the hearing-loss versions.
                if len(sample_audios) < 10 and i < 10:
                    max_len = min(len(audio_left), 3 * sr)
                    sample_audios.append({
                        'clean_left': audio_left[:max_len],
                        'clean_right': audio_right[:max_len],
                        'noisy_left': audio_left_with_loss[:max_len],
                        'noisy_right': audio_right_with_loss[:max_len],
                        'sr': sr,
                        'profile': profile
                    })
                
                # 30 ms frames with a 15 ms hop.
                frame_length = int(0.03 * sr)  
                hop_length = int(0.015 * sr)   
                
                for start in range(0, len(audio) - frame_length, hop_length):
                    left_frame = audio_left_with_loss[start:start+frame_length]
                    right_frame = audio_right_with_loss[start:start+frame_length]
                    
                    _, left_frame_energies = extract_band_energies(left_frame, sr)
                    _, right_frame_energies = extract_band_energies(right_frame, sr)
                    
                    if (len(left_frame_energies) == len(profile_gains) and 
                        len(right_frame_energies) == len(profile_gains)):
                        X_left.append(left_frame_energies)
                        X_right.append(right_frame_energies)
                        y_left.append(profile_gains)
                        y_right.append(profile_gains)
                    
                    # Stop this file/profile once the running total of
                    # examples reaches a multiple of 200. The count is
                    # cumulative across files and profiles, not per file.
                    if len(X_left) % 200 == 0 and len(X_left) > 0:
                        break
        
        except Exception as e:
            # Best effort: report the failing file and move on to the next one.
            print(f"Error processing file {file_name}: {e}")
            continue
    
    if len(X_left) == 0:
        raise ValueError("No valid data extracted from files.")
    
    X_left = np.array(X_left)
    X_right = np.array(X_right)
    y_left = np.array(y_left)
    y_right = np.array(y_right)
    
    # Insert a length-1 middle axis: (examples, 1, bands).
    X_left = X_left.reshape(X_left.shape[0], 1, X_left.shape[1])
    X_right = X_right.reshape(X_right.shape[0], 1, X_right.shape[1])
    
    print(f"Prepared {X_left.shape[0]} diverse binaural examples with {X_left.shape[2]} features")
    return X_left, X_right, y_left, y_right, sample_audios

def enhance_audio_frame(audio_frame_left, audio_frame_right, sr, model):
    """Apply model-predicted per-band gains to one left/right frame pair.

    The band energies of both frames are fed to model.predict as two inputs
    of shape (1, 1, num_bands); the two returned gain vectors are used to
    weight the band signals of the matching ear before summing them.

    Returns:
        (left_enhanced, right_enhanced) with the shapes of the input frames.
    """
    left_bands, left_energies = extract_band_energies(audio_frame_left, sr)
    right_bands, right_energies = extract_band_energies(audio_frame_right, sr)

    model_inputs = [
        left_energies.reshape(1, 1, -1),
        right_energies.reshape(1, 1, -1),
    ]
    left_batch, right_batch = model.predict(model_inputs, verbose=0)
    left_gains, right_gains = left_batch[0], right_batch[0]

    num_bands = len(FREQ_BANDS) - 1

    def weighted_sum(frame, bands, gains):
        # Accumulate gain-weighted bands into a buffer shaped like the frame.
        mixed = np.zeros_like(frame)
        for band in range(num_bands):
            mixed += bands[band] * gains[band]
        return mixed

    left_enhanced = weighted_sum(audio_frame_left, left_bands, left_gains)
    right_enhanced = weighted_sum(audio_frame_right, right_bands, right_gains)
    return left_enhanced, right_enhanced

def enhance_binaural_audio(left_audio, right_audio, sr, model):
    """Enhance a left/right signal pair frame by frame with overlap-add.

    Frames of 30 ms with a 15 ms hop are enhanced by enhance_audio_frame,
    multiplied by a Hann window and summed into the output. The sum is then
    divided by the accumulated window weight (treated as 1 where it is below
    0.001).

    Returns:
        (enhanced_left, enhanced_right), each rescaled to a peak of 0.95
        whenever its peak exceeds 0.95.
    """
    frame_length = int(0.03 * sr)
    hop_length = int(0.015 * sr)
    window = np.hanning(frame_length)

    out_left = np.zeros_like(left_audio)
    out_right = np.zeros_like(right_audio)
    weight_left = np.zeros_like(left_audio)
    weight_right = np.zeros_like(right_audio)

    usable = min(len(left_audio), len(right_audio))
    for start in range(0, usable - frame_length, hop_length):
        stop = start + frame_length
        frame_left, frame_right = enhance_audio_frame(
            left_audio[start:stop], right_audio[start:stop], sr, model
        )

        out_left[start:stop] += frame_left * window
        out_right[start:stop] += frame_right * window
        weight_left[start:stop] += window
        weight_right[start:stop] += window

    # Avoid dividing by (near-)zero weights at samples no window covers.
    weight_left[weight_left < 0.001] = 1.0
    weight_right[weight_right < 0.001] = 1.0
    out_left = out_left / weight_left
    out_right = out_right / weight_right

    def limit_peak(channel):
        peak = np.max(np.abs(channel))
        return channel * (0.95 / peak) if peak > 0.95 else channel

    return limit_peak(out_left), limit_peak(out_right)

def evaluate_model_with_mbstoi(model, sample_audio):
    """Compare MBSTOI scores before and after enhancement on sample excerpts.

    Args:
        model: Model handed to enhance_binaural_audio.
        sample_audio: List of dicts with 'clean_left', 'clean_right',
            'noisy_left', 'noisy_right', 'sr' and optionally 'profile'.

    Returns:
        None when sample_audio is empty; otherwise a dict with the average
        noisy and enhanced scores, their difference, and the per-sample
        scores, profiles and improvements.
    """
    if not sample_audio:
        print("No sample audio available for MBSTOI evaluation")
        return None

    noisy_scores, enhanced_scores, profiles = [], [], []

    for sample in sample_audio:
        ref_left, ref_right = sample['clean_left'], sample['clean_right']
        in_left, in_right = sample['noisy_left'], sample['noisy_right']
        sr = sample['sr']
        profile = sample.get('profile', 'unknown')

        out_left, out_right = enhance_binaural_audio(in_left, in_right, sr, model)

        noisy_scores.append(evaluate_mbstoi(ref_left, ref_right, in_left, in_right))
        enhanced_scores.append(evaluate_mbstoi(ref_left, ref_right, out_left, out_right))
        profiles.append(profile)

    avg_noisy = np.mean(noisy_scores)
    avg_enhanced = np.mean(enhanced_scores)
    improvement = avg_enhanced - avg_noisy

    print(f"MBSTOI Evaluation - Noisy: {avg_noisy:.4f}, Enhanced: {avg_enhanced:.4f}, Improvement: {improvement:.4f}")

    print("\nDetailed per-sample results:")
    for idx, (noisy, enhanced) in enumerate(zip(noisy_scores, enhanced_scores)):
        imp = enhanced - noisy
        print(f"  Sample {idx+1} ({profiles[idx]}): Noisy={noisy:.4f}, Enhanced={enhanced:.4f}, Imp={imp:.4f}")

    per_sample_gain = [e - n for e, n in zip(enhanced_scores, noisy_scores)]
    return {
        'noisy': avg_noisy,
        'enhanced': avg_enhanced,
        'improvement': improvement,
        'noisy_scores': noisy_scores,
        'enhanced_scores': enhanced_scores,
        'profiles': profiles,
        'individual_improvements': per_sample_gain
    }

In [None]:
class ScaleLayer(tf.keras.layers.Layer):
    """Keras layer that multiplies its input by a constant factor.

    The factor is included in get_config so that it is part of the layer's
    serialised configuration.
    """

    def __init__(self, scale=3.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        scaled = inputs * self.scale
        return scaled

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"scale": self.scale})
        return base_config

def build_binaural_model(input_shape, num_bands, model_type='lstm'):
    """Build and compile the two-input, two-output gain prediction model.

    Both ears pass through the same (weight-shared) encoder layers; the two
    encodings are concatenated, fed through a shared dense layer and then
    split into one head per ear. Each head ends in a sigmoid scaled by 3.0.

    Args:
        input_shape: Shape of one input example (without the batch axis).
        num_bands: Number of gains predicted per ear.
        model_type: 'lstm', 'cnn_lstm' or 'deep_lstm'.

    Returns:
        The compiled Keras model (Adam, learning rate 0.001, MSE loss, MAE metric).

    Raises:
        ValueError: If model_type is not one of the supported names.
    """
    left_input = Input(shape=input_shape, name='left_input')
    right_input = Input(shape=input_shape, name='right_input')

    # Encoder description: (shared layer, whether a per-ear Dropout follows it).
    # Layers are instantiated in the same order as before so that Keras
    # assigns the same automatic layer names.
    if model_type == 'lstm':
        encoder = [
            (Bidirectional(LSTM(64, return_sequences=False)), True),
        ]
    elif model_type == 'cnn_lstm':
        encoder = [
            (Conv1D(32, 3, activation='relu', padding='same'), False),
            (Bidirectional(LSTM(64, return_sequences=False)), True),
        ]
    elif model_type == 'deep_lstm':
        encoder = [
            (Bidirectional(LSTM(64, return_sequences=True)), True),
            (Bidirectional(LSTM(64, return_sequences=False)), True),
        ]
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    def encode(features):
        # Run one ear through the shared layers, adding fresh Dropout layers.
        for shared_layer, add_dropout in encoder:
            features = shared_layer(features)
            if add_dropout:
                features = Dropout(0.3)(features)
        return features

    left_encoded = encode(left_input)
    right_encoded = encode(right_input)

    merged = Concatenate()([left_encoded, right_encoded])
    trunk = Dense(128, activation='relu')(merged)
    trunk = Dropout(0.3)(trunk)

    left_hidden = Dense(64, activation='relu')(trunk)
    right_hidden = Dense(64, activation='relu')(trunk)

    left_gains = Dense(num_bands, activation='sigmoid')(left_hidden)
    right_gains = Dense(num_bands, activation='sigmoid')(right_hidden)

    outputs = [
        ScaleLayer(scale=3.0, name='left_scaled')(left_gains),
        ScaleLayer(scale=3.0, name='right_scaled')(right_gains),
    ]

    model = Model(inputs=[left_input, right_input], outputs=outputs)
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='mse',
        metrics={'left_scaled': 'mae', 'right_scaled': 'mae'}
    )

    return model

class MBSTOIEvaluationCallback(Callback):
    """Keras callback that records MBSTOI scores every few epochs.

    Results are appended to self.mbstoi_history, a dict of parallel lists
    keyed by 'epoch', 'noisy', 'enhanced' and 'improvement'.
    """

    def __init__(self, sample_audios, eval_frequency=5):
        super().__init__()
        self.sample_audios = sample_audios
        self.eval_frequency = eval_frequency
        self.mbstoi_history = {
            key: [] for key in ('epoch', 'noisy', 'enhanced', 'improvement')
        }

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the stored samples when the 1-based epoch is due."""
        epoch_number = epoch + 1
        if epoch_number % self.eval_frequency != 0:
            return

        print(f"\nEvaluating MBSTOI at epoch {epoch + 1}...")
        results = evaluate_model_with_mbstoi(self.model, self.sample_audios)
        if not results:
            return

        self.mbstoi_history['epoch'].append(epoch_number)
        for key in ('noisy', 'enhanced', 'improvement'):
            self.mbstoi_history[key].append(results[key])


In [None]:
def _build_training_callbacks(model_type, sample_audios):
    """Create the training callbacks; returns (callback list, MBSTOI callback)."""
    checkpoint = ModelCheckpoint(
        os.path.join(models_dir, f'binaural_{model_type}_model.h5'),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    )

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    )

    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-5,
        verbose=1
    )

    mbstoi_eval = MBSTOIEvaluationCallback(
        sample_audios=sample_audios,
        eval_frequency=5
    )

    return [checkpoint, early_stopping, reduce_lr, mbstoi_eval], mbstoi_eval


def _save_training_history(history, model_type):
    """Write the loss and MAE curves of a Keras History object to a CSV file."""
    history_df = pd.DataFrame({
        'loss': history.history['loss'],
        'val_loss': history.history['val_loss'],
        'left_scaled_mae': history.history['left_scaled_mae'],
        'right_scaled_mae': history.history['right_scaled_mae'],
        'val_left_scaled_mae': history.history['val_left_scaled_mae'],
        'val_right_scaled_mae': history.history['val_right_scaled_mae']
    })
    history_df.to_csv(os.path.join(output_dir, f'training_history_{model_type}.csv'), index=False)


def _plot_training_results(history, mbstoi_history, model_type):
    """Plot loss, MAE and MBSTOI curves, save the figure as PNG and show it."""
    plt.figure(figsize=(16, 12))

    plt.subplot(2, 2, 1)
    plt.plot(history.history['loss'], label='Train')
    plt.plot(history.history['val_loss'], label='Validation')
    plt.title('Loss (MSE)')
    plt.xlabel('Epoch')
    plt.ylabel('Mean Squared Error')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.subplot(2, 2, 2)
    plt.plot(history.history['left_scaled_mae'], label='Left Ear')
    plt.plot(history.history['right_scaled_mae'], label='Right Ear')
    plt.title('Mean Absolute Error')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.subplot(2, 1, 2)
    mbstoi_epochs = mbstoi_history['epoch']
    mbstoi_noisy = mbstoi_history['noisy']
    mbstoi_enhanced = mbstoi_history['enhanced']
    mbstoi_improvement = mbstoi_history['improvement']

    plt.plot(mbstoi_epochs, mbstoi_noisy, 'b-o', label='Noisy')
    plt.plot(mbstoi_epochs, mbstoi_enhanced, 'g-o', label='Enhanced')

    # Label every evaluation point with its improvement (green if positive).
    for i, epoch in enumerate(mbstoi_epochs):
        improvement = mbstoi_improvement[i]
        color = 'green' if improvement > 0 else 'red'
        plt.annotate(f"{improvement:.4f}",
                     xy=(epoch, mbstoi_enhanced[i]),
                     xytext=(0, 10),
                     textcoords="offset points",
                     ha='center', color=color, fontsize=9)

    plt.title('MBSTOI Intelligibility During Training')
    plt.xlabel('Epoch')
    plt.ylabel('MBSTOI Score')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'training_results_{model_type}.png'))
    plt.show()


def train_improved_hearing_aid_model(epochs=200, model_type='lstm', 
                                    hearing_profiles=None, max_files=1000):
    """Prepare data, train the binaural model and save its training artefacts.

    Args:
        epochs: Maximum number of training epochs (early stopping may end sooner).
        model_type: Architecture name passed to build_binaural_model.
        hearing_profiles: Profiles passed to prepare_diverse_binaural_data.
        max_files: Maximum number of training files to process.

    Returns:
        (model, mbstoi_history): the trained model and the history dict
        collected by the MBSTOI evaluation callback.

    Side effects:
        Writes the best checkpoint to models_dir, and a CSV plus a PNG of the
        training curves to output_dir.
    """
    # NOTE(review): candidates are limited to the first 1000 sorted entries
    # before shuffling, regardless of max_files.
    train_files = sorted(os.listdir(train_clean_dir))[:1000]
    random.shuffle(train_files)

    X_left, X_right, y_left, y_right, sample_audios = prepare_diverse_binaural_data(
        train_files,
        train_clean_dir,
        hearing_profiles=hearing_profiles,
        max_files=max_files,
        challenging_audio=True
    )

    (X_left_train, X_left_val,
     X_right_train, X_right_val,
     y_left_train, y_left_val,
     y_right_train, y_right_val) = train_test_split(
        X_left, X_right, y_left, y_right, test_size=0.2, random_state=42
    )

    print(f"Training data shapes: Left input {X_left_train.shape}, Right input {X_right_train.shape}")

    input_shape = (X_left_train.shape[1], X_left_train.shape[2])
    num_bands = y_left_train.shape[1]
    model = build_binaural_model(input_shape, num_bands, model_type=model_type)
    model.summary()

    callbacks, mbstoi_eval = _build_training_callbacks(model_type, sample_audios)

    print("Starting training...")
    history = model.fit(
        [X_left_train, X_right_train],
        [y_left_train, y_right_train],
        validation_data=([X_left_val, X_right_val], [y_left_val, y_right_val]),
        epochs=epochs,
        batch_size=32,
        callbacks=callbacks,
        verbose=1
    )

    _save_training_history(history, model_type)
    _plot_training_results(history, mbstoi_eval.mbstoi_history, model_type)

    return model, mbstoi_eval.mbstoi_history



In [None]:
def create_test_files_with_hearing_loss(max_files=1000):
    """Generate test cases on disk from the clean test recordings.

    For every selected file and each of four hearing profiles, a left/right
    pair is derived, degraded with probability 0.7 (noise, reverberation or
    both, using the same SNR / RT60 for both ears), passed through the
    hearing-loss simulation and written to test_audio_dir/<test id>/.

    Args:
        max_files: Maximum number of (sorted) files from test_clean_dir to
            use. Defaults to 1000, the limit that was previously hard-coded;
            pass None to use every file.

    Returns:
        List of dicts with keys 'id', 'profile', 'condition', 'dir' and 'sr',
        one per written test case.
    """
    test_files = sorted(os.listdir(test_clean_dir))[:max_files]

    test_profiles = ['mild', 'moderate', 'high_freq', 'severe']

    test_cases = []

    for i, file_name in enumerate(test_files):
        try:
            file_path = os.path.join(test_clean_dir, file_name)
            audio, sr = sf.read(file_path)

            # Use at most the first 30 seconds of each recording.
            max_len = min(len(audio), 30 * sr)
            audio = audio[:max_len]

            for profile in test_profiles:
                # The right ear is a copy circularly shifted by 4 samples and
                # scaled by 0.8.
                audio_left = audio.copy()
                audio_right = np.roll(audio, 4) * 0.8

                if np.random.random() < 0.7:
                    condition = np.random.choice(['noise', 'reverb', 'both'])

                    if condition == 'noise' or condition == 'both':
                        noise_type = np.random.choice(['white', 'pink', 'speech_shaped'])
                        snr = np.random.uniform(3, 10)
                        audio_left = add_noise(audio_left, snr_range=(snr, snr), noise_type=noise_type)
                        audio_right = add_noise(audio_right, snr_range=(snr, snr), noise_type=noise_type)

                    if condition == 'reverb' or condition == 'both':
                        rt60 = np.random.uniform(0.3, 0.8)
                        audio_left = add_reverberation(audio_left, sr, rt60_range=(rt60, rt60))
                        audio_right = add_reverberation(audio_right, sr, rt60_range=(rt60, rt60))

                    condition_name = condition
                else:
                    condition_name = 'clean'

                audio_left_with_loss = apply_hearing_loss(audio_left, sr, profile)
                audio_right_with_loss = apply_hearing_loss(audio_right, sr, profile)

                test_id = f"test{i+1}_{profile}_{condition_name}"

                test_dir = os.path.join(test_audio_dir, test_id)
                os.makedirs(test_dir, exist_ok=True)

                # NOTE(review): the 'clean_*' files are written after the
                # optional noise/reverb step, so they are only free of the
                # hearing-loss simulation - confirm that is the intended
                # reference signal.
                sf.write(os.path.join(test_dir, 'clean_left.wav'), audio_left, sr)
                sf.write(os.path.join(test_dir, 'clean_right.wav'), audio_right, sr)
                sf.write(os.path.join(test_dir, 'noisy_left.wav'), audio_left_with_loss, sr)
                sf.write(os.path.join(test_dir, 'noisy_right.wav'), audio_right_with_loss, sr)

                test_cases.append({
                    'id': test_id,
                    'profile': profile,
                    'condition': condition_name,
                    'dir': test_dir,
                    'sr': sr
                })

        except Exception as e:
            # Best effort: report the failing file and continue with the next.
            print(f"Error processing test file {file_name}: {e}")

    print(f"Created {len(test_cases)} test cases in {test_audio_dir}")
    return test_cases

def evaluate_on_test_cases(model, test_cases):
    """Enhance every prepared test case with ``model`` and collect its metrics.

    For each case the four WAV files in the case directory are read, the
    ``noisy_*`` pair is passed through ``enhance_binaural_audio``, and the
    result is written back into the same directory as ``enhanced_left.wav`` /
    ``enhanced_right.wav``. MBSTOI (before and after enhancement) and the
    per-band level change of the left channel are printed and recorded.

    Args:
        model: Forwarded unchanged to ``enhance_binaural_audio``.
        test_cases: Iterable of dicts providing the keys ``'id'``,
            ``'profile'``, ``'condition'``, ``'dir'`` and ``'sr'``.

    Returns:
        dict with four parallel lists:
        ``'mbstoi'`` (dicts with ``test_id``/``noisy``/``enhanced``/``improvement``),
        ``'gains'`` (dicts with ``test_id``/``average``/``per_band``),
        ``'profiles'`` and ``'conditions'``.
    """

    def _load(folder, file_name):
        # The rate stored in the file is discarded; the case's own 'sr' is used.
        samples, _ = sf.read(os.path.join(folder, file_name))
        return samples

    mbstoi_records, gain_records = [], []
    profile_labels, condition_labels = [], []

    for case in test_cases:
        test_id = case['id']
        profile = case['profile']
        condition = case['condition']
        test_dir = case['dir']
        sr = case['sr']

        print(f"\nEvaluating test case: {test_id}")

        clean_left = _load(test_dir, 'clean_left.wav')
        clean_right = _load(test_dir, 'clean_right.wav')
        noisy_left = _load(test_dir, 'noisy_left.wav')
        noisy_right = _load(test_dir, 'noisy_right.wav')

        enhanced_left, enhanced_right = enhance_binaural_audio(
            noisy_left, noisy_right, sr, model
        )

        # Persist the enhanced pair next to the inputs.
        for channel, samples in (('left', enhanced_left), ('right', enhanced_right)):
            sf.write(os.path.join(test_dir, f'enhanced_{channel}.wav'), samples, sr)

        mbstoi_noisy = evaluate_mbstoi(clean_left, clean_right, noisy_left, noisy_right)
        mbstoi_enhanced = evaluate_mbstoi(clean_left, clean_right, enhanced_left, enhanced_right)
        improvement = mbstoi_enhanced - mbstoi_noisy

        print(f"  MBSTOI - Noisy: {mbstoi_noisy:.4f}, Enhanced: {mbstoi_enhanced:.4f}, Improvement: {improvement:.4f}")

        # Per-band gain is measured on the left channel only.
        _, noisy_energies = extract_band_energies(noisy_left, sr)
        _, enhanced_energies = extract_band_energies(enhanced_left, sr)
        gains_db = (
            20 * np.log10(enhanced_energies + 1e-10)
            - 20 * np.log10(noisy_energies + 1e-10)
        )
        average_gain = np.mean(gains_db)

        band_labels = [f"{lo}-{hi}Hz" for lo, hi in zip(FREQ_BANDS, FREQ_BANDS[1:])]

        print("  Frequency-Specific Gains (dB):")
        print(f"  Average gain across bands: {average_gain:.2f} dB")
        for idx, band in enumerate(band_labels):
            print(f"  {band:>12}: {gains_db[idx]:.2f} dB gain")

        mbstoi_records.append({
            'test_id': test_id,
            'noisy': mbstoi_noisy,
            'enhanced': mbstoi_enhanced,
            'improvement': improvement
        })
        gain_records.append({
            'test_id': test_id,
            'average': average_gain,
            'per_band': {band_labels[idx]: gain for idx, gain in enumerate(gains_db)}
        })
        profile_labels.append(profile)
        condition_labels.append(condition)

    return {
        'mbstoi': mbstoi_records,
        'gains': gain_records,
        'profiles': profile_labels,
        'conditions': condition_labels
    }

def analyze_and_plot_results(results):
    """Aggregate evaluation results, plot them, and print a text report.

    Args:
        results: dict with the parallel lists ``'profiles'``,
            ``'conditions'``, ``'mbstoi'`` (dicts holding ``'noisy'``,
            ``'enhanced'``, ``'improvement'``) and ``'gains'`` (dicts holding
            a ``'per_band'`` mapping keyed ``"<lo>-<hi>Hz"``).

    Returns:
        dict with:
        ``'overall'``      - mean/median MBSTOI figures and the percentage of
                             cases with a positive improvement,
        ``'by_profile'``   - DataFrame of mean scores grouped by profile,
        ``'by_condition'`` - DataFrame of mean scores grouped by condition,
        ``'avg_gains'``    - mean gain per frequency band.

    Side effects:
        Saves ``comprehensive_evaluation.png`` into ``output_dir``, shows the
        figure, and prints the summary.
    """

    profiles = results['profiles']
    conditions = results['conditions']
    mbstoi_data = results['mbstoi']
    gains_data = results['gains']
    
    mbstoi_df = pd.DataFrame({
        'Profile': profiles,
        'Condition': conditions,
        'Noisy': [d['noisy'] for d in mbstoi_data],
        'Enhanced': [d['enhanced'] for d in mbstoi_data],
        'Improvement': [d['improvement'] for d in mbstoi_data]
    })
    
    profile_summary = mbstoi_df.groupby('Profile').agg({
        'Noisy': 'mean',
        'Enhanced': 'mean',
        'Improvement': 'mean'
    }).reset_index()
    
    condition_summary = mbstoi_df.groupby('Condition').agg({
        'Noisy': 'mean',
        'Enhanced': 'mean',
        'Improvement': 'mean'
    }).reset_index()
    
    overall_summary = {
        'avg_noisy': mbstoi_df['Noisy'].mean(),
        'avg_enhanced': mbstoi_df['Enhanced'].mean(),
        'avg_improvement': mbstoi_df['Improvement'].mean(),
        'median_improvement': mbstoi_df['Improvement'].median(),
        'positive_improvement_rate': (mbstoi_df['Improvement'] > 0).mean() * 100
    }
    
    # Walk adjacent band edges pairwise. This avoids a list.index() lookup per
    # band, which is quadratic and would pick the wrong neighbour if an edge
    # value ever appeared twice in FREQ_BANDS.
    avg_gains = {}
    for lo, hi in zip(FREQ_BANDS, FREQ_BANDS[1:]):
        band_key = f"{lo}-{hi}Hz"
        # Cases that lack this band contribute 0 to the mean.
        avg_gains[band_key] = np.mean([g['per_band'].get(band_key, 0) for g in gains_data])
    
    plt.figure(figsize=(15, 10))
    
    # Top-left: noisy vs. enhanced MBSTOI, grouped by hearing-loss profile.
    plt.subplot(2, 2, 1)
    x = np.arange(len(profile_summary))
    width = 0.35
    plt.bar(x - width/2, profile_summary['Noisy'], width, label='Noisy')
    plt.bar(x + width/2, profile_summary['Enhanced'], width, label='Enhanced')
    plt.xlabel('Hearing Loss Profile')
    plt.ylabel('MBSTOI Score')
    plt.title('MBSTOI by Hearing Loss Profile')
    plt.xticks(x, profile_summary['Profile'])
    plt.legend()
    plt.grid(alpha=0.3)
    
    # Top-right: mean MBSTOI improvement per acoustic condition.
    plt.subplot(2, 2, 2)
    plt.bar(condition_summary['Condition'], condition_summary['Improvement'])
    plt.xlabel('Acoustic Condition')
    plt.ylabel('MBSTOI Improvement')
    plt.title('Improvement by Acoustic Condition')
    plt.grid(alpha=0.3)
    
    # Bottom row (full width): mean gain per frequency band.
    plt.subplot(2, 1, 2)
    bands = list(avg_gains.keys())
    gains = list(avg_gains.values())
    plt.bar(bands, gains)
    plt.xlabel('Frequency Band')
    plt.ylabel('Gain (dB)')
    plt.title('Average Frequency-Specific Gains')
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'comprehensive_evaluation.png'))
    plt.show()
    
    print("\nComprehensive Evaluation Results:")
    print(f"Overall MBSTOI - Noisy: {overall_summary['avg_noisy']:.4f}, Enhanced: {overall_summary['avg_enhanced']:.4f}")
    print(f"Average improvement: {overall_summary['avg_improvement']:.4f}")
    print(f"Positive improvement rate: {overall_summary['positive_improvement_rate']:.1f}%")
    
    print("\nMBSTOI by Hearing Loss Profile:")
    for _, row in profile_summary.iterrows():
        print(f"  {row['Profile']:>10}: {row['Improvement']:.4f}")
    
    print("\nMBSTOI by Acoustic Condition:")
    for _, row in condition_summary.iterrows():
        print(f"  {row['Condition']:>10}: {row['Improvement']:.4f}")
    
    print("\nAverage Frequency-Specific Gains (dB):")
    print(f"Average gain across bands: {np.mean(list(avg_gains.values())):.2f} dB")
    for band, gain in avg_gains.items():
        print(f"  {band:>12}: {gain:.2f} dB gain")
    
    return {
        'overall': overall_summary,
        'by_profile': profile_summary,
        'by_condition': condition_summary,
        'avg_gains': avg_gains
    }

In [None]:
def main():
    """Train the model, build the test cases, evaluate them, and report."""
    training_options = dict(
        epochs=200,
        model_type='cnn_lstm',
        hearing_profiles=['high_freq', 'moderate', 'mild'],
        max_files=1000,
    )
    # The second return value (the MBSTOI history) is not used here.
    trained_model, _history = train_improved_hearing_aid_model(**training_options)

    cases = create_test_files_with_hearing_loss()
    evaluation = evaluate_on_test_cases(trained_model, cases)
    analyze_and_plot_results(evaluation)

    print("\nImproved hearing aid model training and evaluation complete!")
    print(f"Results saved to {output_dir}")


if __name__ == "__main__":
    main()

In [None]:
class ScaleLayer(tf.keras.layers.Layer):
    """Keras layer that multiplies its input by a constant factor.

    The factor is stored as a plain attribute and included in the layer
    config so it is written out and restored with the model.
    """

    def __init__(self, scale=3.0, **kwargs):
        """Store ``scale``; remaining keyword arguments go to the base layer."""
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        """Return ``inputs`` multiplied by the configured scale."""
        scaled = inputs * self.scale
        return scaled

    def get_config(self):
        """Return the base layer config extended with the ``scale`` entry."""
        return {**super().get_config(), "scale": self.scale}

# Load the saved binaural CNN-LSTM model from the models directory.
models_dir = "improved_hearing_aid_models"
model_path = os.path.join(models_dir, 'binaural_cnn_lstm_model.h5')
model = load_model(
    model_path,
    # Map the custom layer name stored in the file to the class defined above.
    custom_objects={'ScaleLayer': ScaleLayer},
    # Skip restoring the optimizer/loss; the model is only used for predict().
    compile=False
)


from scipy import signal
# Frequency band edges in Hz; adjacent pairs define the analysis bands.
FREQ_BANDS = [0, 500, 1000, 2000, 4000, 8000]


def _design_band_filters(sr):
    """Return one 2nd-order Butterworth ``(b, a)`` pair per FREQ_BANDS interval."""
    nyquist = sr / 2
    filters = []
    for lo, hi in zip(FREQ_BANDS, FREQ_BANDS[1:]):
        if lo <= 0:
            # Lowest band: nothing to cut off below, so a low-pass suffices.
            filters.append(signal.butter(2, hi / nyquist, btype='low'))
        elif hi >= nyquist:
            # Upper edge at or beyond Nyquist: only the lower edge can be filtered.
            filters.append(signal.butter(2, lo / nyquist, btype='high'))
        else:
            filters.append(signal.butter(2, [lo / nyquist, hi / nyquist], btype='band'))
    return filters


def enhance_binaural_audio(left, right, sr, model):
    """Apply model-predicted per-band gains to a left/right signal pair.

    The signals are cut into 30 ms frames with 50 % overlap. Each frame is
    split into the FREQ_BANDS bands with zero-phase Butterworth filters, the
    per-band RMS of both channels is passed to ``model.predict``, and each
    band is scaled by the returned gain. The gained bands are summed,
    Hann-windowed, overlap-added, and divided by the accumulated window
    weight. Finally each channel is independently scaled down if its peak
    exceeds 0.95.

    Args:
        left, right: 1-D sample arrays. If their lengths differ, only the
            common leading part is framed; the rest of the longer output
            stays zero.
        sr: Sample rate in Hz.
        model: Object whose ``predict([L, R], verbose=0)`` takes two arrays of
            shape ``(1, 1, n_bands)`` and returns a pair ``(gL, gR)``.
            Assumes ``gL[0]`` / ``gR[0]`` hold one gain per band - TODO
            confirm against the trained model's output shape.

    Returns:
        ``(outL, outR)``: arrays shaped like ``left`` and ``right``. Samples
        not covered by any full frame remain zero.
    """
    win = int(0.03 * sr)
    hop = win // 2
    window = np.hanning(win)
    outL = np.zeros_like(left); cntL = np.zeros_like(left)
    outR = np.zeros_like(right); cntR = np.zeros_like(right)

    # Frame over the samples both channels share, and include the final full
    # frame (the "+ 1"): without it an input of length win + k*hop loses its
    # last frame and the tail of the output stays silent.
    n_common = min(len(left), len(right))
    frame_starts = range(0, n_common - win + 1, hop)

    # Filter coefficients depend only on sr, so design them once per call
    # instead of once per band per frame.
    band_filters = _design_band_filters(sr) if len(frame_starts) else []

    def extract_feats(x):
        """Return (bands stacked as rows, per-band RMS) for one frame."""
        bands = [signal.filtfilt(b, a, x) for b, a in band_filters]
        en = [np.sqrt(np.mean(band**2)) for band in bands]
        return np.stack(bands), np.array(en)

    for i in frame_starts:
        lseg = left[i:i+win]; rseg = right[i:i+win]
        lb, lE = extract_feats(lseg)
        rb, rE = extract_feats(rseg)
        gL, gR = model.predict([lE.reshape(1,1,-1), rE.reshape(1,1,-1)], verbose=0)
        # Scale every band by its gain, recombine, and taper for overlap-add.
        enhL = (lb * gL[0][:,None]).sum(axis=0) * window
        enhR = (rb * gR[0][:,None]).sum(axis=0) * window
        outL[i:i+win] += enhL; cntL[i:i+win] += window
        outR[i:i+win] += enhR; cntR[i:i+win] += window

    # Avoid dividing by (near-)zero where the window weight never built up.
    cntL[cntL<1e-3] = 1; cntR[cntR<1e-3] = 1
    outL /= cntL; outR /= cntR
    # In-place peak limiting, applied to each channel separately.
    for sig in (outL, outR):
        pk = np.max(np.abs(sig))
        if pk > 0.95: sig *= 0.95/pk
    return outL, outR


# Folder scanned for the noisy *.wav inputs to enhance.
input_dir  = r"C:\Users\rucha\OneDrive\Desktop\Mac files\FCE\Project\testing noisy files"
# Folder that receives the enhanced files; created if it does not exist.
output_dir = "improved_hearing_aid_results"
os.makedirs(output_dir, exist_ok=True)

def rms(x):
    """Return the root-mean-square value of array ``x``."""
    mean_square = np.mean(x * x)
    return np.sqrt(mean_square)

# Enhance every WAV file in input_dir, save the result, and report the level change.
for wav in glob.glob(os.path.join(input_dir, "*.wav")):
    fname = os.path.basename(wav)
    print(f"\n▶ Processing {fname}")

    # read & mono
    audio, sr = sf.read(wav)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)

    # binaural inputs: the right ear is a 4-sample-delayed copy at 0.8x level.
    # The delay is zero-padded rather than built with np.roll, which is a
    # circular shift and would move the last four samples to the very start.
    left  = audio.copy()
    delayed = np.zeros_like(audio)
    delayed[4:] = audio[:-4]
    right = delayed * 0.8

    # enhance, then fold both ears back into a single channel
    L, R = enhance_binaural_audio(left, right, sr, model)
    enhanced = (L + R) / 2

    # save
    out_path = os.path.join(output_dir, f"enh_{fname}")
    sf.write(out_path, enhanced, sr)

    # metrics: overall level change in dB (epsilon guards against silent input)
    r0 = rms(audio)
    r1 = rms(enhanced)
    gain_db = 20 * np.log10((r1 + 1e-12)/(r0 + 1e-12))

    print(f"  RMS (Noisy):    {r0:.4f}")
    print(f"  RMS (Enhanced): {r1:.4f}")
    print(f"  Gain:           {gain_db:+.2f} dB")


In [None]:
# noisereduce is optional: when it is missing, nr is set to None and a
# notice is shown so later code can check `if nr:` before denoising.
try:
    import noisereduce as nr
except ImportError:
    nr = None
    display(HTML("<p style='color:red;'> noisereduce not installed; install with <code>!pip install noisereduce</code> to enable denoising.</p>"))

def load_and_norm(path):
    """Read an audio file, down-mix it to mono, and scale it to unit peak.

    Two-dimensional data is averaged along axis 1. A peak of zero (silence)
    is replaced by 1.0 so the division leaves the signal unchanged.

    Returns:
        ``(samples, sample_rate)``
    """
    samples, rate = sf.read(path)
    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    peak = np.max(np.abs(samples))
    if not peak:
        peak = 1.0
    return samples / peak, rate

# Folder with the original noisy inputs and folder with the "enh_"-prefixed outputs.
input_dir = r"C:\Users\rucha\OneDrive\Desktop\Mac files\FCE\Project\testing noisy files"
enh_dir   = "improved_hearing_aid_results"

# For each noisy input, show an audio player for the original and for its
# enhanced counterpart (denoised with noisereduce when that is available).
for wav_path in glob.glob(os.path.join(input_dir, "*.wav")):
    fname = os.path.basename(wav_path)
    display(HTML(f"<hr><h2>File: {fname}</h2>"))

    noisy, sr = load_and_norm(wav_path)
    display(HTML("<h4>Original Noisy</h4>"))
    # The clip is already peak-normalised, so the player's own scaling is turned off.
    display(Audio(noisy, rate=sr, normalize=False))

    enh_path = os.path.join(enh_dir, f"enh_{fname}")
    if not os.path.exists(enh_path):
        # No enhanced file for this input: report it and move on.
        display(HTML(f"<p style='color:red;'>Enhanced file not found: {enh_path}</p>"))
        continue

    # The rate stored in the enhanced file is discarded; the noisy file's sr is reused below.
    enhanced, _ = load_and_norm(enh_path)

    if nr:
        denh = nr.reduce_noise(y=enhanced, sr=sr)
        # Re-normalise to unit peak after denoising; a zero peak falls back to 1.0.
        denh /= (np.max(np.abs(denh)) or 1.0)
        display(HTML("<h4>Denoised Enhanced</h4>"))
        display(Audio(denh, rate=sr, normalize=False))
    else:
        display(HTML("<p style='color:red;'>⚠️ noisereduce not available, playing raw enhanced.</p>"))
        display(HTML("<h4>Enhanced (raw)</h4>"))
        display(Audio(enhanced, rate=sr, normalize=False))


In [None]:
# noisereduce is optional: when it is missing, nr is set to None and a
# notice is shown so later code can check `if nr:` before denoising.
try:
    import noisereduce as nr
except ImportError:
    nr = None
    display(HTML(
      "<p style='color:red;'>⚠️ `noisereduce` not installed; "
      "install with <code>!pip install noisereduce</code> to enable denoising.</p>"
    ))

def load_mono(path):
    """Read an audio file and return ``(samples, sample_rate)`` as mono.

    Two-dimensional data is averaged along axis 1; anything else is returned
    as read. No level normalisation is applied.
    """
    samples, rate = sf.read(path)
    is_multichannel = samples.ndim == 2
    mono = samples.mean(axis=1) if is_multichannel else samples
    return mono, rate

# Folder with the original noisy inputs and folder with the "enh_"-prefixed outputs.
input_dir = r"C:\Users\rucha\OneDrive\Desktop\Mac files\FCE\Project\testing noisy files"
enh_dir   = "improved_hearing_aid_results"

# For each noisy input, draw waveform and spectrogram of the original, the
# enhanced version, and the denoised enhanced version, one row per signal.
for wav_path in glob.glob(os.path.join(input_dir, "*.wav")):
    fname = os.path.basename(wav_path)
    display(HTML(f"<hr><h2>File: {fname}</h2>"))

    noisy, sr = load_mono(wav_path)

    enh_path = os.path.join(enh_dir, f"enh_{fname}")
    if not os.path.exists(enh_path):
        # No enhanced file for this input: report it and move on.
        display(HTML(f"<p style='color:red;'>Enhanced file not found: {enh_path}</p>"))
        continue
    # The rate stored in the enhanced file is discarded; the noisy file's sr is reused below.
    enhanced, _ = load_mono(enh_path)

    if nr:
        denoised_enh = nr.reduce_noise(y=enhanced, sr=sr)
    else:
        # Without noisereduce the third row shows the enhanced signal again.
        denoised_enh = enhanced

    items = [
        ("Original Noisy", noisy),
        ("LSTM‑Enhanced", enhanced),
        ("Denoised Enhanced", denoised_enh)
    ]

    # One row per signal: waveform in column 0, spectrogram in column 1.
    fig, axes = plt.subplots(len(items), 2, figsize=(12, 4 * len(items)))
    for i, (label, sig) in enumerate(items):
        # Time axis in seconds.
        t = np.arange(len(sig)) / sr

        # waveform
        ax_w = axes[i, 0]
        ax_w.plot(t, sig, linewidth=0.5)
        ax_w.set_title(f"{label} Waveform")
        ax_w.set_xlabel("Time [s]")
        ax_w.set_ylabel("Amplitude")

        # spectrogram
        ax_s = axes[i, 1]
        ax_s.specgram(sig, NFFT=512, Fs=sr, noverlap=256)
        ax_s.set_title(f"{label} Spectrogram")
        ax_s.set_xlabel("Time [s]")
        ax_s.set_ylabel("Freq [Hz]")

    plt.tight_layout()
    plt.show()

In [None]:

# Hearing loss in dB per profile, one value per frequency band.
HEARING_LOSS_PROFILES = {
    'mild': [5, 10, 15, 20, 25],               # Mild loss
    'moderate': [10, 20, 35, 45, 50],          # Moderate loss
    'severe': [20, 35, 55, 70, 80],            # Severe loss
    'high_freq': [0, 5, 15, 35, 60],          # High frequency loss (most common)
    'cookie_bite': [15, 30, 40, 30, 20],       # Largest loss in the middle band
    'reverse_slope': [45, 35, 25, 15, 5],      # Loss decreases toward the higher bands
    'flat': [30, 30, 30, 30, 30]              # Same loss in every band
}

# Band edges in Hz; each tick on the x axis is labelled with a band's lower edge.
FREQ_BANDS = [0, 500, 1000, 2000, 4000, 8000]
x_labels = [f"{FREQ_BANDS[i]}" for i in range(len(FREQ_BANDS)-1)]

plt.figure(figsize=(10, 6))

# One colour per profile, in dictionary order.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2']

for i, (profile, loss_values) in enumerate(HEARING_LOSS_PROFILES.items()):
    plt.plot(range(len(loss_values)), loss_values, 'o-', 
             linewidth=2.5, label=profile.replace('_', ' ').title(),
             color=colors[i], markersize=8)

# Flip the y axis so larger losses are drawn lower on the chart.
plt.gca().invert_yaxis()

# Shaded horizontal zones; the text labels below name each zone.
plt.axhspan(0, 25, alpha=0.1, color='green')
plt.axhspan(25, 40, alpha=0.1, color='yellow')
plt.axhspan(40, 70, alpha=0.1, color='orange')
plt.axhspan(70, 90, alpha=0.1, color='red')

plt.text(4.5, 12.5, 'MILD', fontsize=9, ha='right', color='green', fontweight='bold')
plt.text(4.5, 32.5, 'MODERATE', fontsize=9, ha='right', color='#b5b500', fontweight='bold')
plt.text(4.5, 55, 'SEVERE', fontsize=9, ha='right', color='#b27300', fontweight='bold')
plt.text(4.5, 75, 'PROFOUND', fontsize=9, ha='right', color='#b30000', fontweight='bold')

plt.xticks(range(len(x_labels)), x_labels)
plt.xlabel('Frequency (Hz)', fontsize=12, fontweight='bold')
plt.ylabel('Hearing Loss (dB)', fontsize=12, fontweight='bold')
plt.title('Hearing Loss Profiles Across Frequency Bands', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, linestyle='--')
plt.legend(title='Profiles', title_fontsize=12, fontsize=10, loc='lower left')

# Arrow pointing at the last band of the 'high_freq' profile (x=4, 60 dB).
plt.annotate('High-frequency loss is most common,\nwith greatest impact in 4000-8000Hz range', 
             xy=(4, 60), xytext=(2.5, 35),
             arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
             fontsize=10, ha='center', bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", alpha=0.8))

plt.tight_layout()
# Written to the current working directory.
plt.savefig('hearing_loss_profiles.png', dpi=300)
plt.show()