In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and echo the full path of every file found.
for input_root, _subdirs, input_names in os.walk('/kaggle/input'):
    for input_path in (os.path.join(input_root, name) for name in input_names):
        print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# =============================================
# COMPLETE RSUNet SPEECH DENOISING PIPELINE
# (Corrected SNR Mixing + Training + Denoising)
# =============================================

# ---
# NOTE: This code has been cleaned of non-printable
# characters (like U+00A0) that cause SyntaxErrors
# after copy-pasting.
# ---

import os
import re
import numpy as np
import librosa
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
import IPython.display as ipd

# -------------------------------
# 1️⃣ PATHS
# -------------------------------
# Input datasets (read-only) and the folders this notebook writes to.
CLEAN_DIR = "/kaggle/input/datasetnew/clean_wav/clean_wav"
NOISE_DIR = "/kaggle/input/datasetnew/noise_wav/noise_wav"
MIX_DIR   = "/kaggle/working/mixed_snr_wav"
STFT_DIR  = "/kaggle/working/dataset"
ENHANCED_DIR = "/kaggle/working/enhanced_audio"

# Create the top-level output folders.
for d in [MIX_DIR, STFT_DIR, ENHANCED_DIR]:
    os.makedirs(d, exist_ok=True)

# The STFT dataset is stored in three sibling sub-folders. They are created
# once here, instead of once per iteration of the loop above.
for stft_subdir in ("noisy_mag", "clean_mag", "phase"):
    os.makedirs(os.path.join(STFT_DIR, stft_subdir), exist_ok=True)

# -------------------------------
# 2️⃣ SNR MIXING FUNCTIONS (CORRECTED)
# -------------------------------
def calculate_rms(audio):
    """Return the root-mean-square value of ``audio``.

    The input is squared element-wise, averaged over all elements with
    ``np.mean`` and the square root of that mean is returned.
    """
    mean_square = np.mean(audio ** 2)
    return np.sqrt(mean_square)

def adjust_noise_to_snr(clean, noise, target_snr_db):
    """Scale ``noise`` so that the clean-to-noise RMS ratio matches ``target_snr_db``.

    Args:
        clean: Reference signal; only its RMS level is used.
        noise: Noise signal to rescale.
        target_snr_db: Desired signal-to-noise ratio in decibels.

    Returns:
        ``noise`` multiplied by a single scalar gain.
    """
    eps = 1e-8  # keeps both divisions finite for silent inputs
    power_ratio = 10 ** (target_snr_db / 10)
    speech_level = calculate_rms(clean)
    noise_level = calculate_rms(noise)
    # An amplitude ratio is the square root of the power ratio.
    desired_noise_level = speech_level / (np.sqrt(power_ratio) + eps)
    gain = desired_noise_level / (noise_level + eps)
    return noise * gain

def mix_at_snr(clean, noise, target_snr_db):
    """Mix ``clean`` speech with ``noise`` at the requested SNR.

    Both signals are trimmed to the shorter length, the clean signal is
    peak-normalised, the noise is rescaled with ``adjust_noise_to_snr`` and the
    sum is clipped to [-1, 1].

    Returns:
        ``(mixed, scaled_noise, actual_snr)`` where ``actual_snr`` is the SNR in
        dB measured between the normalised clean signal and the scaled noise,
        i.e. before clipping.
    """
    # Work on the overlapping part only.
    n_samples = min(len(clean), len(noise))
    noise = noise[:n_samples]
    clean = clean[:n_samples]

    # Peak-normalise the speech so its largest sample is (almost exactly) 1.
    peak = np.max(np.abs(clean))
    clean = clean / (peak + 1e-8)

    scaled_noise = adjust_noise_to_snr(clean, noise, target_snr_db)

    # The mixture is deliberately NOT re-normalised: rescaling it would break
    # its sample-wise relationship to the clean target. Overshoot beyond
    # full scale is clipped instead.
    mixed = np.clip(clean + scaled_noise, -1.0, 1.0)

    speech_power = np.mean(clean ** 2)
    noise_power = np.mean(scaled_noise ** 2)
    actual_snr = 10 * np.log10(speech_power / (noise_power + 1e-8))
    return mixed, scaled_noise, actual_snr

def prepare_snr_mixes(clean_dir, noise_dir, mix_dir, snr_levels=(-5, 0, 5, 10), sr=16000):
    """Write one noisy mixture per (clean file, noise file, SNR level) combination.

    Args:
        clean_dir: Folder containing the clean ``.wav`` files.
        noise_dir: Folder containing the noise ``.wav`` files.
        mix_dir: Existing folder that receives the mixtures, named
            ``<clean>_<noise>_snr<level>dB.wav``.
        snr_levels: Iterable of target SNRs in dB.
        sr: Sample rate used both for loading and for writing.

    The achieved SNR is printed for the first two clean files combined with the
    first two noise files as a sanity check.
    """
    clean_files = sorted(f for f in os.listdir(clean_dir) if f.endswith(".wav"))
    noise_files = sorted(f for f in os.listdir(noise_dir) if f.endswith(".wav"))
    print(f"Processing {len(clean_files)} clean files and {len(noise_files)} noise files")

    # Decode every noise recording once up front; previously each noise file was
    # re-read and resampled for every clean file.
    noise_tracks = [
        (noise_file[:-4], librosa.load(os.path.join(noise_dir, noise_file), sr=sr)[0])
        for noise_file in noise_files
    ]

    for clean_idx, clean_file in enumerate(tqdm(clean_files, desc="Mixing audio")):
        clean_audio, _ = librosa.load(os.path.join(clean_dir, clean_file), sr=sr)
        clean_stem = clean_file[:-4]
        for noise_idx, (noise_stem, noise_audio) in enumerate(noise_tracks):
            # enumerate() replaces the list.index() lookups that ran for every mixture.
            report = clean_idx < 2 and noise_idx < 2
            for snr_db in snr_levels:
                mixed_audio, _, actual_snr = mix_at_snr(clean_audio, noise_audio, snr_db)
                mix_name = f"{clean_stem}_{noise_stem}_snr{snr_db}dB.wav"
                sf.write(os.path.join(mix_dir, mix_name), mixed_audio, sr)
                if report:
                    print(f"Target SNR: {snr_db} dB, Actual SNR: {actual_snr:.2f} dB")
    print(f"Mixed files created: {len(os.listdir(mix_dir))}")

# Build the noisy mixtures. Files left over from an earlier run are not
# removed by this call, so clear MIX_DIR and STFT_DIR first when regenerating.
prepare_snr_mixes(clean_dir=CLEAN_DIR, noise_dir=NOISE_DIR, mix_dir=MIX_DIR)

# -------------------------------
# 3️⃣ STFT DATASET PREPARATION (NOW CORRECT)
# -------------------------------
def _find_clean_path(clean_dir, mixed_name):
    """Return the clean ``.wav`` path matching a mixture file name, or ``None``.

    Mixture names look like ``<clean>_<noise>_snr<level>dB.wav`` and the clean
    stem may itself contain underscores, so progressively longer
    underscore-joined prefixes are tried, shortest first. When the first token
    alone matches, that file is used.
    """
    tokens = os.path.splitext(mixed_name)[0].split("_")
    for n_tokens in range(1, len(tokens) + 1):
        candidate = os.path.join(clean_dir, "_".join(tokens[:n_tokens]) + ".wav")
        if os.path.exists(candidate):
            return candidate
    return None


def prepare_stft_dataset(clean_dir, mixed_dir, output_root, sr=16000, n_fft=1024, hop_length=256):
    """Convert every mixture into magnitude/phase ``.npy`` training files.

    For each ``.wav`` in ``mixed_dir`` the matching clean file is looked up in
    ``clean_dir``; both signals are trimmed to a common length and transformed
    with ``librosa.stft``. Three arrays are saved under ``output_root``:
    ``noisy_mag/<name>_noisy_mag.npy``, ``clean_mag/<name>_clean_mag.npy`` and
    ``phase/<name>_phase.npy`` (phase of the noisy STFT).

    Mixtures without a matching clean file are skipped and reported at the end.
    """
    mixed_files = [f for f in os.listdir(mixed_dir) if f.endswith(".wav")]
    skipped = []
    for f in tqdm(mixed_files, desc="Creating STFT"):
        clean_path = _find_clean_path(clean_dir, f)
        if clean_path is None:
            skipped.append(f)
            continue

        noisy, _ = librosa.load(os.path.join(mixed_dir, f), sr=sr)
        clean_orig, _ = librosa.load(clean_path, sr=sr)

        min_len = min(len(noisy), len(clean_orig))
        noisy = noisy[:min_len]
        clean_orig = clean_orig[:min_len]

        # The mixing step peak-normalised the (trimmed) clean signal before
        # adding noise, so the target must be built from the same normalisation.
        clean_normalized = clean_orig / (np.max(np.abs(clean_orig)) + 1e-8)

        noisy_stft = librosa.stft(noisy, n_fft=n_fft, hop_length=hop_length)
        clean_stft = librosa.stft(clean_normalized, n_fft=n_fft, hop_length=hop_length)

        noisy_mag, clean_mag, phase = np.abs(noisy_stft), np.abs(clean_stft), np.angle(noisy_stft)
        base_name = f.replace(".wav", "")
        np.save(os.path.join(output_root, "noisy_mag", f"{base_name}_noisy_mag.npy"), noisy_mag)
        np.save(os.path.join(output_root, "clean_mag", f"{base_name}_clean_mag.npy"), clean_mag)
        np.save(os.path.join(output_root, "phase", f"{base_name}_phase.npy"), phase)

    if skipped:
        # Make the previously silent skip visible.
        print(f"Skipped {len(skipped)} mixture(s) with no matching clean file (e.g. {skipped[0]})")

# Turn the mixtures into magnitude/phase arrays for training.
prepare_stft_dataset(clean_dir=CLEAN_DIR, mixed_dir=MIX_DIR, output_root=STFT_DIR)

# -------------------------------
# 4️⃣ DATASET CLASS
# -------------------------------
class SpeechDataset(Dataset):
    """Pairs of (noisy, clean) magnitude spectrograms stored as ``.npy`` files.

    Every item is a tuple of two float tensors shaped ``(1, freq, max_frames)``:
    the noisy input and the clean target, each zero-padded or cropped along the
    time axis.

    The clean file for each noisy file is derived from the noisy file's name
    (``X_noisy_mag.npy`` -> ``X_clean_mag.npy``; any other name is looked up
    unchanged). Pairing two independently sorted listings by position is not
    safe: ``a_d_noisy_mag.npy`` sorts before ``a_noisy_mag.npy`` while
    ``a_clean_mag.npy`` sorts before ``a_d_clean_mag.npy``.

    Raises:
        FileNotFoundError: if a noisy file has no clean counterpart.
    """

    # Suffixes used by the STFT preparation step.
    NOISY_SUFFIX = "_noisy_mag.npy"
    CLEAN_SUFFIX = "_clean_mag.npy"

    def __init__(self, noisy_dir, clean_dir, max_frames=256):
        self.noisy_dir, self.clean_dir, self.max_frames = noisy_dir, clean_dir, max_frames
        self.noisy_files = sorted(f for f in os.listdir(noisy_dir) if f.endswith(".npy"))
        # clean_files[i] is always the partner of noisy_files[i].
        self.clean_files = [self._clean_name_for(f) for f in self.noisy_files]

        available = {f for f in os.listdir(clean_dir) if f.endswith(".npy")}
        missing = [c for c in self.clean_files if c not in available]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} clean spectrogram(s) missing in {clean_dir}, e.g. {missing[0]}"
            )

    def _clean_name_for(self, noisy_name):
        """Map a noisy file name to the name of its clean counterpart."""
        if noisy_name.endswith(self.NOISY_SUFFIX):
            return noisy_name[:-len(self.NOISY_SUFFIX)] + self.CLEAN_SUFFIX
        return noisy_name

    def pad_or_crop(self, x):
        """Zero-pad or truncate a ``(freq, time)`` array to ``max_frames`` columns."""
        freq, time = x.shape
        if time < self.max_frames:
            pad = np.zeros((freq, self.max_frames - time))
            x = np.concatenate([x, pad], axis=1)
        else:
            x = x[:, :self.max_frames]
        return x

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        noisy = np.load(os.path.join(self.noisy_dir, self.noisy_files[idx]))
        clean = np.load(os.path.join(self.clean_dir, self.clean_files[idx]))
        noisy, clean = self.pad_or_crop(noisy), self.pad_or_crop(clean)
        # Add the channel dimension expected by Conv2d.
        return torch.tensor(noisy).unsqueeze(0).float(), torch.tensor(clean).unsqueeze(0).float()

noisy_dir_stft = os.path.join(STFT_DIR,"noisy_mag")
clean_dir_stft = os.path.join(STFT_DIR,"clean_mag")
dataset = SpeechDataset(noisy_dir_stft, clean_dir_stft, max_frames=256)

# 80 % train / 10 % validation; the test split takes whatever remains so the
# three lengths always sum to len(dataset).
train_len = int(0.8*len(dataset))
val_len   = int(0.1*len(dataset))
test_len  = len(dataset)-train_len-val_len

# Seed the split so the same files land in each subset on every run; otherwise
# results are not reproducible and the "held-out" set changes between runs.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(dataset,[train_len,val_len,test_len], generator=split_generator)

train_dl = DataLoader(train_ds,batch_size=8,shuffle=True)
val_dl   = DataLoader(val_ds,batch_size=8)
test_dl  = DataLoader(test_ds,batch_size=8)
print(f"Dataset sizes ‚Äî Train: {train_len}, Val: {val_len}, Test: {test_len}")

# -------------------------------
# 5️⃣ RSUNET MODEL
# -------------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convolution + batch-norm stages with a shortcut connection.

    The shortcut is a 1x1 convolution when the channel count changes and an
    identity otherwise; it is added before the final ReLU.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Layers are created in the original order so that seeded weight
        # initialisation and the state-dict layout stay the same.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        if in_c != out_c:
            self.skip = nn.Conv2d(in_c, out_c, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        shortcut = self.skip(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + shortcut)

class RSUNet(nn.Module):
    """U-Net built from ``ResidualBlock`` stages.

    Three encoder stages (32/64/128 channels) separated by 2x2 max-pooling, a
    256-channel bottleneck, and three decoder stages that upsample with a
    stride-2 transposed convolution, concatenate the matching encoder output
    and refine it. A 1x1 convolution maps the result to a single channel.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = ResidualBlock(1, 32)
        self.enc2 = ResidualBlock(32, 64)
        self.enc3 = ResidualBlock(64, 128)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ResidualBlock(128, 256)
        # Decoder: each dec* block receives upsampled features concatenated
        # with the encoder features, hence the doubled input channels.
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = ResidualBlock(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = ResidualBlock(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = ResidualBlock(64, 32)
        self.out_conv = nn.Conv2d(32, 1, 1)

    def crop_or_pad(self, src, target):
        """Pad ``src`` so its last two dimensions match those of ``target``.

        The size difference is split between both sides of each dimension and
        passed to ``F.pad`` (a negative difference is passed on as negative
        padding).
        """
        extra_h = target.size(2) - src.size(2)
        extra_w = target.size(3) - src.size(3)
        left, top = extra_w // 2, extra_h // 2
        return F.pad(src, [left, extra_w - left, top, extra_h - top])

    def _merge(self, upsampled, skip):
        """Size-match ``upsampled`` to ``skip`` and concatenate them on the channel axis."""
        return torch.cat([self.crop_or_pad(upsampled, skip), skip], 1)

    def forward(self, x):
        # Contracting path
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        # Expanding path with skip connections
        d3 = self.dec3(self._merge(self.up3(b), e3))
        d2 = self.dec2(self._merge(self.up2(d3), e2))
        d1 = self.dec1(self._merge(self.up1(d2), e1))
        return self.out_conv(d1)

# -------------------------------
# 6️⃣ TRAINING FUNCTION
# -------------------------------
def train_model(model, train_dl, val_dl, name, epochs=30, lr=1e-3, patience=7):
    """Train ``model`` with MSE loss, Adam and early stopping on validation loss.

    Args:
        model: Network mapping a noisy-magnitude batch to a clean-magnitude batch.
        train_dl: DataLoader yielding ``(noisy, clean)`` training batches.
        val_dl: DataLoader yielding ``(noisy, clean)`` validation batches.
        name: Tag used in the checkpoint file name and the plot title.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        patience: Number of consecutive epochs without validation improvement
            after which training stops.

    Returns:
        The model carrying the weights of the best validation epoch. Those
        weights are also saved to ``/kaggle/working/best_<name>.pth``.

    Raises:
        ValueError: if either DataLoader yields no batches.
    """
    if len(train_dl) == 0 or len(val_dl) == 0:
        # The per-epoch averages divide by the number of batches.
        raise ValueError("train_dl and val_dl must each contain at least one batch")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    checkpoint_path = f"/kaggle/working/best_{name}.pth"
    best_val_loss = float('inf')
    patience_counter = 0
    checkpoint_saved = False
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0
        for x, y in train_dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_dl)

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y in val_dl:
                x, y = x.to(device), y.to(device)
                val_loss += criterion(model(x), y).item()
        val_loss /= len(val_dl)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({'model_state_dict': model.state_dict()}, checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
        if (epoch+1) % 10 == 0:
            print(f"Epoch {epoch+1}: Train Loss {train_loss:.6f}, Val Loss {val_loss:.6f}")

    # When training stops, the in-memory weights belong to the LAST epoch, which
    # is up to `patience` epochs past the best one. Restore the best checkpoint
    # so the returned model matches the file on disk.
    if checkpoint_saved:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    plt.figure(figsize=(10,5))
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Val')
    plt.title(f'{name} Training')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(True)
    plt.show()
    return model

# -------------------------------
# 7️⃣ SNR CALCULATION (NOW CORRECT)
# -------------------------------
def calculate_snr_proper(clean,noisy):
    """Return the SNR in dB of ``noisy`` measured against the reference ``clean``.

    Both signals are trimmed to the shorter length; everything in ``noisy``
    that is not ``clean`` counts as noise. The result is only meaningful when
    the two signals are sample-aligned and share the same scaling.
    """
    n_samples = min(len(clean), len(noisy))
    reference = clean[:n_samples]
    residual = noisy[:n_samples] - reference
    signal_power = np.mean(reference ** 2)
    residual_power = np.mean(residual ** 2)
    # The small constant keeps the ratio finite when the residual is zero.
    return 10 * np.log10(signal_power / (residual_power + 1e-8))

# -------------------------------
# 8️⃣ AUDIO RECONSTRUCTION
# -------------------------------
def reconstruct_audio(model,noisy_mag_path,phase_path,output_dir,sr=16000,max_frames=256,hop_length=None):
    """Denoise one saved spectrogram and write the result as a ``.wav`` file.

    Args:
        model: Trained network taking a ``(1, 1, freq, max_frames)`` tensor.
        noisy_mag_path: Path to a ``*_noisy_mag.npy`` magnitude array (freq, time).
        phase_path: Path to the matching noisy-phase array.
        output_dir: Folder for the enhanced audio (created if missing).
        sr: Sample rate written to the output file.
        max_frames: Number of time frames fed to the model; longer inputs are
            truncated, shorter ones are zero-padded.
        hop_length: Hop length for the inverse STFT. ``None`` keeps librosa's
            default; pass the analysis hop when it differs from that default.

    Returns:
        Path of the written ``*_enhanced.wav`` file.
    """
    os.makedirs(output_dir,exist_ok=True)
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model=model.to(device); model.eval()

    noisy_mag = np.load(noisy_mag_path)
    phase = np.load(phase_path)
    freq,time=noisy_mag.shape
    if time<max_frames:
        noisy_mag_input=np.concatenate([noisy_mag,np.zeros((freq,max_frames-time))],axis=1)
    else:
        noisy_mag_input=noisy_mag[:,:max_frames]

    noisy_tensor=torch.tensor(noisy_mag_input).unsqueeze(0).unsqueeze(0).float().to(device)
    with torch.no_grad():
        # Index batch and channel explicitly instead of squeeze(), which would
        # also drop a frequency or time axis of size 1.
        enhanced_mag=model(noisy_tensor)[0,0].cpu().numpy()

    # Keep only the frames that correspond to real input. Frames produced from
    # the zero padding would otherwise be appended to the audio as artefacts.
    n_frames=min(time,max_frames,enhanced_mag.shape[1])
    # The network has no output non-linearity, so it can predict negative
    # values. A negative "magnitude" would flip the phase of that bin, so clamp.
    enhanced_mag=np.maximum(enhanced_mag[:,:n_frames],0.0)

    freq_phase,time_phase=phase.shape
    if time_phase<n_frames:
        phase_aligned=np.concatenate([phase,np.zeros((freq_phase,n_frames-time_phase))],axis=1)
    else:
        phase_aligned=phase[:,:n_frames]

    # Recombine the enhanced magnitude with the noisy phase.
    enhanced_stft=enhanced_mag*np.exp(1j*phase_aligned)
    enhanced_audio=librosa.istft(enhanced_stft,hop_length=hop_length)

    output_path=os.path.join(output_dir, os.path.basename(noisy_mag_path).replace("_noisy_mag.npy","_enhanced.wav"))
    sf.write(output_path, enhanced_audio, sr)
    print(f"Saved enhanced audio: {output_path}")
    return output_path

# -------------------------------
# 9️⃣ TRAIN + DENOISING EXECUTION (CORRECTED)
# -------------------------------
print("Training RSUNet on corrected dataset...")
rsunet = RSUNet()
trained_model = train_model(rsunet, train_dl, val_dl, "rsunet_fixed", epochs=30)

# Evaluate up to 3 samples from the held-out test split. Taking the first
# files of the STFT folder instead would mostly pick files the model was
# trained on and report optimistic numbers.
eval_indices = sorted(test_ds.indices)[:3]
files = [dataset.noisy_files[i] for i in eval_indices]
for idx,f in enumerate(files):
    base_name = f.replace("_noisy_mag.npy", "")
    noisy_path = os.path.join(noisy_dir_stft,f)
    phase_path = os.path.join(STFT_DIR,"phase",f.replace("_noisy_mag.npy","_phase.npy"))
    # The mixture .wav shares its base name with the spectrogram file.
    noisy_wav_path = os.path.join(MIX_DIR, base_name + ".wav")

    # Recover the clean file id from "<clean>_<noise>_snr<level>dB". The clean
    # id may itself contain underscores, so try growing prefixes and fall back
    # to the first token when nothing matches.
    name_parts = base_name.split("_")
    clean_id = name_parts[0]
    for n_parts in range(1, len(name_parts) + 1):
        candidate_id = "_".join(name_parts[:n_parts])
        if os.path.exists(os.path.join(CLEAN_DIR, candidate_id + ".wav")):
            clean_id = candidate_id
            break
    clean_wav_path = os.path.join(CLEAN_DIR, clean_id+".wav")

    enhanced_wav_path = reconstruct_audio(trained_model,noisy_path,phase_path,ENHANCED_DIR)

    # Load reference, input and output at the pipeline sample rate.
    clean_orig, _ = librosa.load(clean_wav_path, sr=16000)
    noisy, _ = librosa.load(noisy_wav_path, sr=16000)
    enhanced, _ = librosa.load(enhanced_wav_path, sr=16000)

    # Trim everything to the common length before comparing.
    min_len = min(len(clean_orig), len(noisy), len(enhanced))
    clean_orig = clean_orig[:min_len]
    noisy = noisy[:min_len]
    enhanced = enhanced[:min_len]

    # The mixtures and the training targets were built from the peak-normalised
    # clean signal, so SNR must be measured against that same reference.
    clean_normalized = clean_orig / (np.max(np.abs(clean_orig)) + 1e-8)

    in_snr = calculate_snr_proper(clean_normalized, noisy)
    out_snr = calculate_snr_proper(clean_normalized, enhanced)

    print(f"\nüéß Sample {idx+1}: {clean_id} ({base_name})")
    print(f"Input SNR: {in_snr:.2f} dB | Output SNR: {out_snr:.2f} dB | Œî: {out_snr-in_snr:.2f} dB")

    plt.figure(figsize=(15,3))
    # Plot the normalized clean audio for a true comparison
    plt.plot(clean_normalized, label="Clean (Normalized)", color='green')
    plt.plot(noisy, label=f"Noisy ({in_snr:.2f} dB)", color='red')
    plt.plot(enhanced, label=f"Denoised ({out_snr:.2f} dB)", color='blue')
    plt.legend(); plt.grid(True); plt.show()

    # Display audio (using original clean for listening)
    display(ipd.Audio(clean_orig, rate=16000))
    display(ipd.Audio(noisy, rate=16000))
    display(ipd.Audio(enhanced, rate=16000))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import librosa
import librosa.display
from IPython.display import display, Audio
import os

def plot_individual_signals(clean, noisy, enhanced, sr=16000, title_suffix=""):
    """Draw the clean, noisy and enhanced waveforms, each in its own figure.

    The noisy and enhanced titles include the SNR measured against ``clean``.
    All three signals are plotted against a time axis derived from
    ``len(clean)`` and ``sr``.
    """
    time = np.arange(len(clean)) / sr

    # (title heading, samples, line colour, whether the title reports an SNR)
    panels = (
        ("Clean Speech Signal", clean, 'green', False),
        ("Noisy Signal", noisy, 'red', True),
        ("Enhanced Signal", enhanced, 'blue', True),
    )
    for heading, signal, colour, with_snr in panels:
        plt.figure(figsize=(15, 4))
        plt.plot(time, signal, color=colour, linewidth=1.5, alpha=0.8)
        if with_snr:
            current_snr = calculate_snr_proper(clean, signal)
            plt.title(f"{heading} {title_suffix} (SNR: {current_snr:.2f} dB)",
                      fontsize=14, fontweight='bold')
        else:
            plt.title(f"{heading} {title_suffix}", fontsize=14, fontweight='bold')
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.grid(True, alpha=0.3)
        plt.xlim(0, time[-1])
        plt.tight_layout()
        plt.show()

def plot_spectrogram_comparison(clean, noisy, enhanced, sr=16000, title_suffix=""):
    """Show log-frequency dB spectrograms of the three signals, stacked vertically.

    Each spectrogram is scaled relative to its own maximum (``ref=np.max``).
    The noisy and enhanced titles report the SNR measured against ``clean``.
    """
    fig, axes = plt.subplots(3, 1, figsize=(16, 12))

    # STFT settings used for display
    n_fft, hop_length = 1024, 256

    rows = (
        ("Clean", clean, False),
        ("Noisy", noisy, True),
        ("Enhanced", enhanced, True),
    )
    for ax, (label, signal, with_snr) in zip(axes, rows):
        spec = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)
        spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)
        img = librosa.display.specshow(spec_db, sr=sr, hop_length=hop_length,
                                       x_axis='time', y_axis='log', ax=ax)
        if with_snr:
            ax.set_title(f"{label} Spectrogram {title_suffix} (SNR: {calculate_snr_proper(clean, signal):.2f} dB)",
                         fontsize=12, fontweight='bold')
        else:
            ax.set_title(f"{label} Spectrogram {title_suffix}", fontsize=12, fontweight='bold')
        ax.set_ylabel("Frequency (Hz)")
        if ax is axes[2]:
            # Only the bottom panel gets an explicit time-axis label.
            ax.set_xlabel("Time (s)")
        plt.colorbar(img, ax=ax)

    plt.tight_layout()
    plt.show()

def plot_waveform_comparison(clean, noisy, enhanced, sr=16000, title_suffix=""):
    """Overlay the clean, noisy and enhanced waveforms in a single figure.

    The legend entries of the noisy and enhanced traces include their SNR
    measured against ``clean``.
    """
    time = np.arange(len(clean)) / sr
    noisy_label = f'Noisy (SNR: {calculate_snr_proper(clean, noisy):.2f} dB)'
    enhanced_label = f'Enhanced (SNR: {calculate_snr_proper(clean, enhanced):.2f} dB)'

    plt.figure(figsize=(16, 6))
    traces = (
        (clean, 'Clean', 'green', 2, 0.8),
        (noisy, noisy_label, 'red', 1.5, 0.7),
        (enhanced, enhanced_label, 'blue', 1.5, 0.8),
    )
    for signal, label, colour, line_width, opacity in traces:
        plt.plot(time, signal, label=label, color=colour, linewidth=line_width, alpha=opacity)

    plt.title(f"Waveform Comparison {title_suffix}", fontsize=14, fontweight='bold')
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xlim(0, time[-1])
    plt.tight_layout()
    plt.show()

def plot_magnitude_spectrum(clean, noisy, enhanced, sr=16000, title_suffix=""):
    """Overlay the FFT magnitude spectra of the three signals.

    Only bins with a frequency in (0, 8000] Hz are drawn. The frequency axis is
    derived from ``len(clean)``, so the three signals are expected to have the
    same length.
    """
    freqs = np.fft.fftfreq(len(clean), 1/sr)
    # Keep the positive half of the spectrum up to 8 kHz.
    band = (freqs > 0) & (freqs <= 8000)

    plt.figure(figsize=(14, 6))

    curves = (
        (clean, 'Clean', 'green', 2, 0.8),
        (noisy, 'Noisy', 'red', 1.5, 0.7),
        (enhanced, 'Enhanced', 'blue', 1.5, 0.8),
    )
    for signal, label, colour, line_width, opacity in curves:
        spectrum = np.fft.fft(signal)
        plt.plot(freqs[band], np.abs(spectrum[band]),
                 label=label, color=colour, linewidth=line_width, alpha=opacity)

    plt.title(f"Magnitude Spectrum Comparison {title_suffix}", fontsize=14, fontweight='bold')
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def plot_snr_improvement_bar(snr_results):
    """Draw grouped bars of input SNR, output SNR and improvement per sample.

    Args:
        snr_results: Sequence of dicts with the keys ``'input_snr'``,
            ``'output_snr'`` and ``'improvement'`` (all in dB).
    """
    samples = [f"Sample {i+1}" for i in range(len(snr_results))]
    input_snrs = [result['input_snr'] for result in snr_results]
    output_snrs = [result['output_snr'] for result in snr_results]
    improvements = [result['improvement'] for result in snr_results]

    x = np.arange(len(samples))
    width = 0.25

    plt.figure(figsize=(12, 6))

    plt.bar(x - width, input_snrs, width, label='Input SNR', color='red', alpha=0.7)
    plt.bar(x, output_snrs, width, label='Output SNR', color='blue', alpha=0.7)
    plt.bar(x + width, improvements, width, label='SNR Improvement', color='green', alpha=0.7)

    plt.xlabel('Samples')
    plt.ylabel('SNR (dB)')
    plt.title('SNR Improvement Across Samples', fontsize=14, fontweight='bold')
    plt.xticks(x, samples)
    plt.legend()
    plt.grid(True, alpha=0.3)

    def _label_bar(xpos, value, text):
        # Place the label just beyond the bar's tip: above positive bars and
        # below negative ones (input SNRs are often negative), so it never
        # overlaps the bar.
        if value >= 0:
            plt.text(xpos, value + 0.5, text, ha='center', va='bottom', fontweight='bold')
        else:
            plt.text(xpos, value - 0.5, text, ha='center', va='top', fontweight='bold')

    for i, (in_snr, out_snr, imp) in enumerate(zip(input_snrs, output_snrs, improvements)):
        _label_bar(i - width, in_snr, f'{in_snr:.1f}')
        _label_bar(i, out_snr, f'{out_snr:.1f}')
        # '+' in the format spec yields "+1.2" / "-1.2"; a hard-coded '+' prefix
        # produced "+-1.2" whenever the model made things worse.
        _label_bar(i + width, imp, f'{imp:+.1f}')

    plt.tight_layout()
    plt.show()

# Enhanced evaluation function
def comprehensive_evaluation(model, num_samples=5, files=None):
    """Denoise a set of samples, report their SNRs and draw every plot type.

    Args:
        model: Trained network passed on to ``reconstruct_audio``.
        num_samples: Number of files taken from the sorted listing of
            ``noisy_dir_stft`` when ``files`` is not given.
        files: Optional explicit list of ``*_noisy_mag.npy`` file names inside
            ``noisy_dir_stft`` (for example the files of the held-out split).
            When given, ``num_samples`` is ignored.

    Returns:
        A list with one dict per sample holding ``'sample'``, ``'input_snr'``,
        ``'output_snr'`` and ``'improvement'``.
    """
    if files is None:
        files = sorted(os.listdir(noisy_dir_stft))[:num_samples]
    else:
        files = list(files)
    snr_results = []

    print(f"üîä Evaluating {len(files)} samples...")

    for idx, f in enumerate(files):
        base_name = f.replace("_noisy_mag.npy", "")
        noisy_path = os.path.join(noisy_dir_stft, f)
        phase_path = os.path.join(STFT_DIR, "phase", f.replace("_noisy_mag.npy", "_phase.npy"))
        noisy_wav_path = os.path.join(MIX_DIR, base_name + ".wav")

        # Recover the clean file id from "<clean>_<noise>_snr<level>dB". The
        # clean id may itself contain underscores, so try growing prefixes and
        # fall back to the first token when nothing matches.
        name_parts = base_name.split("_")
        clean_id = name_parts[0]
        for n_parts in range(1, len(name_parts) + 1):
            candidate_id = "_".join(name_parts[:n_parts])
            if os.path.exists(os.path.join(CLEAN_DIR, candidate_id + ".wav")):
                clean_id = candidate_id
                break
        clean_wav_path = os.path.join(CLEAN_DIR, clean_id + ".wav")

        print(f"\n{'='*60}")
        print(f"üéß Processing Sample {idx+1}/{len(files)}: {base_name}")
        print(f"{'='*60}")

        # Reconstruct enhanced audio
        enhanced_wav_path = reconstruct_audio(model, noisy_path, phase_path, ENHANCED_DIR)

        # Load all audio files
        clean_orig, _ = librosa.load(clean_wav_path, sr=16000)
        noisy, _ = librosa.load(noisy_wav_path, sr=16000)
        enhanced, _ = librosa.load(enhanced_wav_path, sr=16000)

        # Trim to the common length, then peak-normalise the reference exactly
        # as the mixing step did.
        min_len = min(len(clean_orig), len(noisy), len(enhanced))
        clean_orig = clean_orig[:min_len]
        noisy = noisy[:min_len]
        enhanced = enhanced[:min_len]

        clean_normalized = clean_orig / (np.max(np.abs(clean_orig)) + 1e-8)

        # Calculate SNR
        in_snr = calculate_snr_proper(clean_normalized, noisy)
        out_snr = calculate_snr_proper(clean_normalized, enhanced)
        snr_improvement = out_snr - in_snr

        snr_results.append({
            'sample': base_name,
            'input_snr': in_snr,
            'output_snr': out_snr,
            'improvement': snr_improvement
        })

        print(f"üìä Results:")
        print(f"   ‚Ä¢ Input SNR: {in_snr:.2f} dB")
        print(f"   ‚Ä¢ Output SNR: {out_snr:.2f} dB")
        print(f"   ‚Ä¢ SNR Improvement: {snr_improvement:+.2f} dB")

        # Generate all plots for this sample
        plot_individual_signals(clean_normalized, noisy, enhanced, title_suffix=f"- {base_name}")
        plot_waveform_comparison(clean_normalized, noisy, enhanced, title_suffix=f"- {base_name}")
        plot_spectrogram_comparison(clean_normalized, noisy, enhanced, title_suffix=f"- {base_name}")
        plot_magnitude_spectrum(clean_normalized, noisy, enhanced, title_suffix=f"- {base_name}")

        # Display audio players
        print("\nüîä Audio Playback:")
        display(Audio(clean_orig, rate=16000))
        display(Audio(noisy, rate=16000))
        display(Audio(enhanced, rate=16000))

    # Summary statistics and bar chart only make sense for several samples.
    if len(snr_results) > 1:
        print(f"\n{'='*60}")
        print("üìà SUMMARY STATISTICS")
        print(f"{'='*60}")

        improvements = [r['improvement'] for r in snr_results]
        avg_improvement = np.mean(improvements)
        max_improvement = np.max(improvements)
        min_improvement = np.min(improvements)

        print(f"Average SNR Improvement: {avg_improvement:.2f} dB")
        print(f"Maximum SNR Improvement: {max_improvement:.2f} dB")
        print(f"Minimum SNR Improvement: {min_improvement:.2f} dB")

        plot_snr_improvement_bar(snr_results)

    return snr_results

# Kick off the full evaluation (metrics, plots and audio players) on five samples.
print("Starting comprehensive evaluation...")
snr_results = comprehensive_evaluation(model=trained_model, num_samples=5)

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Folder that receives every PNG written by the plotting helpers below.
PLOTS_DIR = "/kaggle/working/plots"
os.makedirs(name=PLOTS_DIR, exist_ok=True)
print(f"üìÅ Plot directory created: {PLOTS_DIR}")

def plot_individual_signals(clean, noisy, enhanced, sr=16000, title_suffix="", save_plots=True):
    """Draw the clean, noisy and enhanced waveforms in separate figures.

    When ``save_plots`` is true, each figure is also written to ``PLOTS_DIR``
    as ``<kind>_signal_<suffix>.png``, where ``<suffix>`` is ``title_suffix``
    with spaces and hyphens replaced by underscores. Every figure is shown and
    then closed.
    """
    time = np.arange(len(clean)) / sr
    # Shared tail of the three output file names.
    file_tag = title_suffix.replace(' ', '_').replace('-', '_')

    # (file-name prefix, title heading, samples, line colour)
    panels = (
        ("clean", "Clean Speech Signal", clean, 'green'),
        ("noisy", "Noisy Signal", noisy, 'red'),
        ("enhanced", "Enhanced Signal", enhanced, 'blue'),
    )
    for kind, heading, signal, colour in panels:
        plt.figure(figsize=(15, 4))
        plt.plot(time, signal, color=colour, linewidth=1.5, alpha=0.8)
        if kind == "clean":
            plt.title(f"{heading} {title_suffix}", fontsize=14, fontweight='bold')
        else:
            # Degraded and restored signals report their SNR against the reference.
            current_snr = calculate_snr_proper(clean, signal)
            plt.title(f"{heading} {title_suffix} (SNR: {current_snr:.2f} dB)",
                      fontsize=14, fontweight='bold')
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.grid(True, alpha=0.3)
        plt.xlim(0, time[-1])
        plt.tight_layout()

        if save_plots:
            filename = f"{kind}_signal_{file_tag}.png"
            plt.savefig(os.path.join(PLOTS_DIR, filename), dpi=300, bbox_inches='tight', facecolor='white')
            print(f"üíæ Saved: {filename}")
        plt.show()
        plt.close()

def plot_spectrogram_comparison(clean, noisy, enhanced, sr=16000, title_suffix="", save_plots=True):
    """Show stacked log-frequency dB spectrograms of the three signals.

    Each spectrogram is scaled relative to its own maximum. When ``save_plots``
    is true the figure is also written to ``PLOTS_DIR`` as
    ``spectrogram_comparison_<suffix>.png``. The figure is shown and then closed.
    """
    fig, axes = plt.subplots(3, 1, figsize=(16, 12))

    # STFT settings used for display
    n_fft, hop_length = 1024, 256

    rows = (
        ("Clean", clean, False),
        ("Noisy", noisy, True),
        ("Enhanced", enhanced, True),
    )
    for ax, (label, signal, with_snr) in zip(axes, rows):
        spec = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)
        spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)
        img = librosa.display.specshow(spec_db, sr=sr, hop_length=hop_length,
                                       x_axis='time', y_axis='log', ax=ax)
        if with_snr:
            ax.set_title(f"{label} Spectrogram {title_suffix} (SNR: {calculate_snr_proper(clean, signal):.2f} dB)",
                         fontsize=12, fontweight='bold')
        else:
            ax.set_title(f"{label} Spectrogram {title_suffix}", fontsize=12, fontweight='bold')
        ax.set_ylabel("Frequency (Hz)")
        if ax is axes[2]:
            # Only the bottom panel gets an explicit time-axis label.
            ax.set_xlabel("Time (s)")
        plt.colorbar(img, ax=ax)

    plt.tight_layout()

    if save_plots:
        file_tag = title_suffix.replace(' ', '_').replace('-', '_')
        spec_filename = f"spectrogram_comparison_{file_tag}.png"
        plt.savefig(os.path.join(PLOTS_DIR, spec_filename), dpi=300, bbox_inches='tight', facecolor='white')
        print(f"üíæ Saved: {spec_filename}")

    plt.show()
    plt.close()

def plot_waveform_comparison(clean, noisy, enhanced, sr=16000, title_suffix="", save_plots=True):
    """Overlay the clean, noisy and enhanced waveforms on a shared time axis.

    The legend entries of the noisy and enhanced traces carry the SNR that
    ``calculate_snr_proper`` reports for them against ``clean``.

    Args:
        clean: Reference signal; its length defines the time axis.
        noisy: Degraded signal, drawn against the same time axis.
        enhanced: Model output, drawn against the same time axis.
        sr: Sampling rate used to convert sample indices to seconds.
        title_suffix: Text appended to the figure title and, with spaces and
            hyphens replaced by underscores, to the PNG file name.
        save_plots: When true, the figure is written into ``PLOTS_DIR``
            before it is shown.
    """
    # Sample index -> seconds, derived from the clean signal only.
    seconds = np.arange(len(clean)) / sr

    plt.figure(figsize=(16, 6))

    # The reference trace has no SNR annotation and is drawn slightly thicker.
    plt.plot(seconds, clean, label='Clean', color='green', linewidth=2, alpha=0.8)

    # The two comparison traces only differ in caption, colour and opacity.
    for signal, caption, colour, opacity in (
        (noisy, 'Noisy', 'red', 0.7),
        (enhanced, 'Enhanced', 'blue', 0.8),
    ):
        snr = calculate_snr_proper(clean, signal)
        plt.plot(seconds, signal, label=f'{caption} (SNR: {snr:.2f} dB)',
                 color=colour, linewidth=1.5, alpha=opacity)

    plt.title(f"Waveform Comparison {title_suffix}", fontsize=14, fontweight='bold')
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xlim(0, seconds[-1])
    plt.tight_layout()

    if save_plots:
        # Spaces and hyphens in the suffix become underscores in the file name.
        safe_suffix = title_suffix.replace(' ', '_').replace('-', '_')
        waveform_filename = f"waveform_comparison_{safe_suffix}.png"
        plt.savefig(os.path.join(PLOTS_DIR, waveform_filename),
                    dpi=300, bbox_inches='tight', facecolor='white')
        print(f"üíæ Saved: {waveform_filename}")

    plt.show()
    plt.close()

def plot_magnitude_spectrum(clean, noisy, enhanced, sr=16000, title_suffix="", save_plots=True):
    """Plot the FFT magnitude of the clean, noisy and enhanced signals together.

    Each signal is transformed with a full-length ``np.fft.fft``; only the
    bins whose frequency is strictly positive and at most 8000 Hz are drawn.
    The frequency axis is derived from the length of ``clean``.

    Args:
        clean: Reference signal; its length defines the frequency bins.
        noisy: Degraded signal.
        enhanced: Model output.
        sr: Sampling rate used to convert bin indices to frequencies.
        title_suffix: Text appended to the figure title and, with spaces and
            hyphens replaced by underscores, to the PNG file name.
        save_plots: When true, the figure is written into ``PLOTS_DIR``
            before it is shown.
    """
    # One spectrum per signal, kept in drawing order.
    spectra = [np.fft.fft(signal) for signal in (clean, noisy, enhanced)]

    # Frequency of every FFT bin and the mask selecting the displayed band.
    bin_freqs = np.fft.fftfreq(len(clean), 1/sr)
    in_band = (bin_freqs > 0) & (bin_freqs <= 8000)  # Focus on speech frequencies

    plt.figure(figsize=(14, 6))

    # (legend label, colour, line width, opacity), aligned with `spectra`.
    trace_styles = (
        ('Clean', 'green', 2, 0.8),
        ('Noisy', 'red', 1.5, 0.7),
        ('Enhanced', 'blue', 1.5, 0.8),
    )
    for spectrum, (caption, colour, width, opacity) in zip(spectra, trace_styles):
        plt.plot(bin_freqs[in_band], np.abs(spectrum[in_band]),
                 label=caption, color=colour, linewidth=width, alpha=opacity)

    plt.title(f"Magnitude Spectrum Comparison {title_suffix}", fontsize=14, fontweight='bold')
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_plots:
        # Spaces and hyphens in the suffix become underscores in the file name.
        safe_suffix = title_suffix.replace(' ', '_').replace('-', '_')
        spectrum_filename = f"magnitude_spectrum_{safe_suffix}.png"
        plt.savefig(os.path.join(PLOTS_DIR, spectrum_filename),
                    dpi=300, bbox_inches='tight', facecolor='white')
        print(f"üíæ Saved: {spectrum_filename}")

    plt.show()
    plt.close()

def plot_snr_improvement_bar(snr_results, save_plots=True):
    """Draw grouped bars of input SNR, output SNR and SNR improvement per sample.

    Every bar is annotated with its value. Annotations sit just beyond the
    end of the bar: above it for non-negative values, below it for negative
    ones. The improvement label always carries an explicit sign.

    Args:
        snr_results: Sequence of dicts, each providing the keys
            ``'input_snr'``, ``'output_snr'`` and ``'improvement'`` (dB).
        save_plots: When true, the figure is written to
            ``snr_improvement_summary.png`` inside ``PLOTS_DIR`` before it
            is shown.
    """
    samples = [f"Sample {i+1}" for i in range(len(snr_results))]
    input_snrs = [result['input_snr'] for result in snr_results]
    output_snrs = [result['output_snr'] for result in snr_results]
    improvements = [result['improvement'] for result in snr_results]

    x = np.arange(len(samples))
    width = 0.25

    plt.figure(figsize=(12, 6))

    plt.bar(x - width, input_snrs, width, label='Input SNR', color='red', alpha=0.7)
    plt.bar(x, output_snrs, width, label='Output SNR', color='blue', alpha=0.7)
    plt.bar(x + width, improvements, width, label='SNR Improvement', color='green', alpha=0.7)

    plt.xlabel('Samples')
    plt.ylabel('SNR (dB)')
    plt.title('SNR Improvement Across Samples', fontsize=14, fontweight='bold')
    plt.xticks(x, samples)
    plt.legend()
    plt.grid(True, alpha=0.3)

    def _label_bar(x_pos, value, text):
        """Write ``text`` just past the end of a bar, on the side the bar points to."""
        if value < 0:
            # Negative bars grow downwards, so the label goes underneath the bar tip.
            plt.text(x_pos, value - 0.5, text, ha='center', va='top', fontweight='bold')
        else:
            plt.text(x_pos, value + 0.5, text, ha='center', va='bottom', fontweight='bold')

    # Add value labels on bars
    for i, (in_snr, out_snr, imp) in enumerate(zip(input_snrs, output_snrs, improvements)):
        _label_bar(i - width, in_snr, f'{in_snr:.1f}')
        _label_bar(i, out_snr, f'{out_snr:.1f}')
        # The '+' format flag emits the sign itself, so a degradation reads
        # "-1.5" instead of the "+-1.5" a hard-coded '+' prefix would produce.
        _label_bar(i + width, imp, f'{imp:+.1f}')

    plt.tight_layout()

    if save_plots:
        snr_filename = "snr_improvement_summary.png"
        snr_filepath = os.path.join(PLOTS_DIR, snr_filename)
        plt.savefig(snr_filepath, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"üíæ Saved: {snr_filename}")

    plt.show()
    plt.close()

In [None]:
def comprehensive_evaluation_with_saving(model, num_samples=5, save_plots=True):
    """Evaluate the model on the first ``num_samples`` noisy magnitude files and plot each result.

    For every selected ``*_noisy_mag.npy`` file in ``noisy_dir_stft`` this:

    1. derives the matching phase, clean-wav and noisy-wav paths from the file name,
    2. calls ``reconstruct_audio`` and loads the wav file it returns,
    3. loads the clean and noisy wavs at 16 kHz and trims all three signals
       to their common length,
    4. measures input and output SNR with ``calculate_snr_proper`` against
       the peak-normalised clean signal,
    5. draws the per-sample plots and shows audio players.

    When more than one sample was evaluated, summary statistics and the SNR
    bar chart are produced as well.

    Args:
        model: Passed through unchanged to ``reconstruct_audio``.
        num_samples: Maximum number of files (in sorted name order) to evaluate.
        save_plots: Forwarded to every plotting helper; when true, a summary
            of the PNG files found in ``PLOTS_DIR`` is printed at the end.

    Returns:
        list[dict]: One entry per evaluated file with the keys ``'sample'``,
        ``'input_snr'``, ``'output_snr'`` and ``'improvement'``.
    """
    sample_rate = 16000
    mag_suffix = "_noisy_mag.npy"

    # Every path below is derived by substituting this suffix, so entries that
    # do not carry it are dropped before the first num_samples are taken.
    files = sorted(
        name for name in os.listdir(noisy_dir_stft) if name.endswith(mag_suffix)
    )[:num_samples]
    snr_results = []

    print(f"üîä Evaluating {len(files)} samples...")
    print(f"üìÅ Saving plots to: {PLOTS_DIR}")

    for idx, mag_name in enumerate(files):
        # The text before the first underscore names the clean recording.
        clean_id = mag_name.split("_")[0]
        noisy_path = os.path.join(noisy_dir_stft, mag_name)
        phase_path = os.path.join(STFT_DIR, "phase", mag_name.replace(mag_suffix, "_phase.npy"))
        clean_wav_path = os.path.join(CLEAN_DIR, clean_id + ".wav")

        base_name = mag_name.replace(mag_suffix, "")
        noisy_wav_path = os.path.join(MIX_DIR, base_name + ".wav")

        print(f"\n{'='*60}")
        print(f"üéß Processing Sample {idx+1}/{len(files)}: {base_name}")
        print(f"{'='*60}")

        # Reconstruct enhanced audio
        enhanced_wav_path = reconstruct_audio(model, noisy_path, phase_path, ENHANCED_DIR)

        # Load all audio files
        clean_orig, _ = librosa.load(clean_wav_path, sr=sample_rate)
        noisy, _ = librosa.load(noisy_wav_path, sr=sample_rate)
        enhanced, _ = librosa.load(enhanced_wav_path, sr=sample_rate)

        # Align lengths and normalize
        min_len = min(len(clean_orig), len(noisy), len(enhanced))
        clean_orig = clean_orig[:min_len]
        noisy = noisy[:min_len]
        enhanced = enhanced[:min_len]

        # Only the clean reference is peak-normalised; the epsilon guards an all-zero signal.
        clean_normalized = clean_orig / (np.max(np.abs(clean_orig)) + 1e-8)

        # Calculate SNR
        in_snr = calculate_snr_proper(clean_normalized, noisy)
        out_snr = calculate_snr_proper(clean_normalized, enhanced)
        snr_improvement = out_snr - in_snr

        snr_results.append({
            'sample': base_name,
            'input_snr': in_snr,
            'output_snr': out_snr,
            'improvement': snr_improvement
        })

        print(f"üìä Results:")
        print(f"   ‚Ä¢ Input SNR: {in_snr:.2f} dB")
        print(f"   ‚Ä¢ Output SNR: {out_snr:.2f} dB")
        print(f"   ‚Ä¢ SNR Improvement: {snr_improvement:+.2f} dB")

        # Clean up the title suffix for filenames
        clean_title_suffix = base_name.replace(" ", "_").replace("-", "_")

        # Generate all plots for this sample with saving
        plot_individual_signals(clean_normalized, noisy, enhanced,
                                title_suffix=clean_title_suffix, save_plots=save_plots)
        plot_waveform_comparison(clean_normalized, noisy, enhanced,
                                 title_suffix=clean_title_suffix, save_plots=save_plots)
        plot_spectrogram_comparison(clean_normalized, noisy, enhanced,
                                    title_suffix=clean_title_suffix, save_plots=save_plots)
        plot_magnitude_spectrum(clean_normalized, noisy, enhanced,
                                title_suffix=clean_title_suffix, save_plots=save_plots)

        # Display audio players. The widgets are resolved through the `ipd`
        # alias imported at the top of the notebook, so playback does not
        # depend on a separate `from IPython.display import ...` cell.
        print("\nüîä Audio Playback:")
        ipd.display(ipd.Audio(clean_orig, rate=sample_rate))
        ipd.display(ipd.Audio(noisy, rate=sample_rate))
        ipd.display(ipd.Audio(enhanced, rate=sample_rate))

    # Summary plots
    if len(snr_results) > 1:
        print(f"\n{'='*60}")
        print("üìà SUMMARY STATISTICS")
        print(f"{'='*60}")

        improvements = [r['improvement'] for r in snr_results]
        avg_improvement = np.mean(improvements)
        max_improvement = np.max(improvements)
        min_improvement = np.min(improvements)

        print(f"Average SNR Improvement: {avg_improvement:.2f} dB")
        print(f"Maximum SNR Improvement: {max_improvement:.2f} dB")
        print(f"Minimum SNR Improvement: {min_improvement:.2f} dB")

        plot_snr_improvement_bar(snr_results, save_plots=save_plots)

    # Print summary of saved files
    if save_plots:
        # Count PNG files only, so the total is not inflated by other entries in PLOTS_DIR.
        png_files = [name for name in os.listdir(PLOTS_DIR) if name.endswith('.png')]
        print(f"\nüìÇ Total plots saved: {len(png_files)}")
        print("üìä File breakdown:")
        for file_type in ['clean_signal', 'noisy_signal', 'enhanced_signal',
                          'waveform_comparison', 'spectrogram_comparison', 'magnitude_spectrum']:
            count = len([name for name in png_files if name.startswith(file_type)])
            if count > 0:
                print(f"   ‚Ä¢ {file_type}: {count} files")

    return snr_results

In [None]:
def plot_training_progress_with_saving(train_losses, val_losses, save_plots=True):
    """Render a 2x2 dashboard of the loss history and optionally save it.

    Panels, left to right and top to bottom: the loss curves on a linear
    axis, the same curves on a logarithmic y-axis, the per-epoch
    validation-minus-training gap, and a 5-epoch moving average of both
    curves. The last panel is only filled when more than 5 epochs exist.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        save_plots: When true, the figure is written to
            ``training_progress_analysis.png`` inside ``PLOTS_DIR`` before
            it is shown.
    """
    fig, panels = plt.subplots(2, 2, figsize=(16, 12))
    (linear_ax, log_ax), (gap_ax, smooth_ax) = panels

    # Epochs are numbered from 1 on every x-axis.
    epochs = range(1, len(train_losses) + 1)

    # Panels 1 and 2 show the same two curves; only the drawing method,
    # y-label and title differ.
    curve_panels = (
        (linear_ax, linear_ax.plot, 'Loss', 'Training Progress'),
        (log_ax, log_ax.semilogy, 'Loss (log scale)', 'Training Progress (Log Scale)'),
    )
    for axis, draw, y_label, heading in curve_panels:
        draw(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
        draw(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(heading, fontweight='bold')
        axis.legend()
        axis.grid(True, alpha=0.3)

    # Panel 3: validation loss minus training loss, epoch by epoch.
    gap_per_epoch = [v_loss - t_loss for t_loss, v_loss in zip(train_losses, val_losses)]
    gap_ax.plot(epochs, gap_per_epoch, 'g-', linewidth=2)
    gap_ax.set_xlabel('Epoch')
    gap_ax.set_ylabel('Validation - Training Loss')
    gap_ax.set_title('Generalization Gap', fontweight='bold')
    gap_ax.grid(True, alpha=0.3)

    # Panel 4: box-filter smoothing; needs more epochs than the window is wide.
    window = 5
    if len(train_losses) > window:
        kernel = np.ones(window) / window
        train_ma, val_ma = [
            np.convolve(series, kernel, mode='valid') for series in (train_losses, val_losses)
        ]
        # 'valid' convolution drops the first window-1 points, so the x-axis starts later.
        tail_epochs = epochs[window-1:]
        smooth_ax.plot(tail_epochs, train_ma, 'b-', label='Train MA', linewidth=2)
        smooth_ax.plot(tail_epochs, val_ma, 'r-', label='Val MA', linewidth=2)
        smooth_ax.set_xlabel('Epoch')
        smooth_ax.set_ylabel('Loss (Moving Average)')
        smooth_ax.set_title(f'Moving Average (Window={window})', fontweight='bold')
        smooth_ax.legend()
        smooth_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_plots:
        training_filename = "training_progress_analysis.png"
        plt.savefig(os.path.join(PLOTS_DIR, training_filename),
                    dpi=300, bbox_inches='tight', facecolor='white')
        print(f"üíæ Saved: {training_filename}")

    plt.show()
    plt.close()

In [None]:
# Evaluate the trained model on the first five samples, saving every figure.
print("üöÄ Starting comprehensive evaluation with plot saving...")
snr_results = comprehensive_evaluation_with_saving(
    trained_model,
    num_samples=5,
    save_plots=True,
)

# The loss histories only exist if training ran in this session; a missing
# name is reported instead of stopping the cell.
try:
    plot_training_progress_with_saving(train_losses, val_losses, save_plots=True)
except NameError:
    print("‚ö†Ô∏è Training history not available for plotting")

# Report what ended up in the plots directory.
print(f"\nüéâ Evaluation complete! All plots saved to: {PLOTS_DIR}")
saved_files = sorted(os.listdir(PLOTS_DIR))
png_total = sum(1 for entry in saved_files if entry.endswith('.png'))
print(f"üìÑ Total PNG files created: {png_total}")

# Numbered listing, capped at the first 20 entries.
print("\nüìã Saved files:")
for position, entry in enumerate(saved_files[:20], start=1):
    print(f"   {position:2d}. {entry}")
remaining = len(saved_files) - 20
if remaining > 0:
    print(f"   ... and {remaining} more files")

In [None]:
import numpy as np
import librosa

# Manual spot check on one utterance: the clean reference, its noisy mix
# and the enhanced output written for that mix.

clean_path = "/kaggle/input/datasetnew/clean_wav/clean_wav/03-01-01-01-01-01-02.wav"
noisy_path = "/kaggle/working/mixed_snr_wav/03-01-01-01-01-01-02_ambience_snr-5dB.wav"
enhanced_path = "/kaggle/working/enhanced_audio/03-01-01-01-01-01-02_ambience_snr-5dB_enhanced.wav"

sr = 16000

# librosa.load returns (samples, rate); only the samples are kept.
clean, noisy, enhanced = (
    librosa.load(wav_path, sr=sr)[0]
    for wav_path in (clean_path, noisy_path, enhanced_path)
)

# Cut all three signals down to the shortest one so they can be compared sample by sample.
min_len = min(map(len, (clean, noisy, enhanced)))
clean, noisy, enhanced = clean[:min_len], noisy[:min_len], enhanced[:min_len]

# Scale the clean signal to unit peak; the epsilon guards an all-zero signal.
clean_norm = clean / (np.max(np.abs(clean)) + 1e-8)

# Define simple SNR calculation
def snr_db(reference, estimate):
    """Return the SNR of ``estimate`` measured against ``reference``, in decibels.

    The value is ``10 * log10(sum(reference**2) / (sum((reference - estimate)**2) + 1e-8))``.
    The small constant keeps the denominator non-zero when both inputs are identical.
    """
    residual = reference - estimate
    reference_energy = np.sum(reference ** 2)
    residual_energy = np.sum(residual ** 2)
    return 10 * np.log10(reference_energy / (residual_energy + 1e-8))

# SNR of the noisy mix and of the enhanced output, both against the normalised clean signal.
input_snr_check, output_snr_check = (
    snr_db(clean_norm, candidate) for candidate in (noisy, enhanced)
)
improvement_check = output_snr_check - input_snr_check

# One print call; the emitted text is the same three lines as before.
print(
    f"Input SNR (manual check): {input_snr_check:.2f} dB\n"
    f"Output SNR (manual check): {output_snr_check:.2f} dB\n"
    f"SNR Improvement: {improvement_check:.2f} dB"
)
